xref: /aosp_15_r20/external/pytorch/torch/autograd/gradcheck.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport collections
3*da0073e9SAndroid Build Coastguard Workerimport functools
4*da0073e9SAndroid Build Coastguard Workerimport warnings
5*da0073e9SAndroid Build Coastguard Workerfrom itertools import product
6*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
7*da0073e9SAndroid Build Coastguard Workerfrom typing_extensions import deprecated
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport torch
10*da0073e9SAndroid Build Coastguard Workerimport torch.testing
11*da0073e9SAndroid Build Coastguard Workerfrom torch._vmap_internals import _vmap, vmap
12*da0073e9SAndroid Build Coastguard Workerfrom torch.overrides import is_tensor_like
13*da0073e9SAndroid Build Coastguard Workerfrom torch.types import _TensorOrTensors
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker# Note: `get_*_jacobian` functions are added here even though we didn't intend to make them public
17*da0073e9SAndroid Build Coastguard Worker# since they have been exposed from before we added `__all__`  and we already maintain BC for them
18*da0073e9SAndroid Build Coastguard Worker# We should eventually deprecate them and remove them from `__all__`
19*da0073e9SAndroid Build Coastguard Worker__all__ = [
20*da0073e9SAndroid Build Coastguard Worker    "gradcheck",
21*da0073e9SAndroid Build Coastguard Worker    "gradgradcheck",
22*da0073e9SAndroid Build Coastguard Worker    "GradcheckError",
23*da0073e9SAndroid Build Coastguard Worker    "get_numerical_jacobian",
24*da0073e9SAndroid Build Coastguard Worker    "get_analytical_jacobian",
25*da0073e9SAndroid Build Coastguard Worker    "get_numerical_jacobian_wrt_specific_input",
26*da0073e9SAndroid Build Coastguard Worker]
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Workerclass GradcheckError(RuntimeError):
30*da0073e9SAndroid Build Coastguard Worker    r"""Error raised by :func:`gradcheck` and :func:`gradgradcheck`."""
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Workerdef _is_sparse_compressed_tensor(obj: torch.Tensor):
34*da0073e9SAndroid Build Coastguard Worker    return obj.layout in {
35*da0073e9SAndroid Build Coastguard Worker        torch.sparse_csr,
36*da0073e9SAndroid Build Coastguard Worker        torch.sparse_csc,
37*da0073e9SAndroid Build Coastguard Worker        torch.sparse_bsr,
38*da0073e9SAndroid Build Coastguard Worker        torch.sparse_bsc,
39*da0073e9SAndroid Build Coastguard Worker    }
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Workerdef _is_sparse_any_tensor(obj: torch.Tensor):
43*da0073e9SAndroid Build Coastguard Worker    return _is_sparse_compressed_tensor(obj) or obj.layout is torch.sparse_coo
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Workerdef _is_float_or_complex_tensor(obj):
47*da0073e9SAndroid Build Coastguard Worker    return is_tensor_like(obj) and (obj.is_floating_point() or obj.is_complex())
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker
50*da0073e9SAndroid Build Coastguard Workerdef _allocate_jacobians_with_inputs(
51*da0073e9SAndroid Build Coastguard Worker    input_tensors: Tuple, numel_output
52*da0073e9SAndroid Build Coastguard Worker) -> Tuple[torch.Tensor, ...]:
53*da0073e9SAndroid Build Coastguard Worker    # Makes zero-filled tensors from inputs. If `numel_output` is not None, for
54*da0073e9SAndroid Build Coastguard Worker    # each tensor in `input_tensors`, returns a new zero-filled tensor with height
55*da0073e9SAndroid Build Coastguard Worker    # of `t.numel` and width of `numel_output`. Otherwise, for each tensor, returns
56*da0073e9SAndroid Build Coastguard Worker    # a 1-d tensor with size `(t.numel,)`. Each new tensor will be strided and have
57*da0073e9SAndroid Build Coastguard Worker    # the same dtype and device as those of the corresponding input.
58*da0073e9SAndroid Build Coastguard Worker    out: List[torch.Tensor] = []
59*da0073e9SAndroid Build Coastguard Worker    for t in input_tensors:
60*da0073e9SAndroid Build Coastguard Worker        if _is_float_or_complex_tensor(t) and t.requires_grad:
61*da0073e9SAndroid Build Coastguard Worker            out.append(t.new_zeros((t.numel(), numel_output), layout=torch.strided))
62*da0073e9SAndroid Build Coastguard Worker    return tuple(out)
63*da0073e9SAndroid Build Coastguard Worker
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Workerdef _allocate_jacobians_with_outputs(
66*da0073e9SAndroid Build Coastguard Worker    output_tensors: Tuple, numel_input, dtype=None, device=None
67*da0073e9SAndroid Build Coastguard Worker) -> Tuple[torch.Tensor, ...]:
68*da0073e9SAndroid Build Coastguard Worker    # Makes zero-filled tensors from outputs. If `dim` is not None, for each tensor
69*da0073e9SAndroid Build Coastguard Worker    # in `output_tensors`, returns a new zero-filled tensor with height of `dim` and
70*da0073e9SAndroid Build Coastguard Worker    # width of `t.numel`. Otherwise, for each tensor, returns a 1-d tensor with size
71*da0073e9SAndroid Build Coastguard Worker    # (t.numel,).
72*da0073e9SAndroid Build Coastguard Worker    out: List[torch.Tensor] = []
73*da0073e9SAndroid Build Coastguard Worker    options = {"dtype": dtype, "device": device, "layout": torch.strided}
74*da0073e9SAndroid Build Coastguard Worker    for t in output_tensors:
75*da0073e9SAndroid Build Coastguard Worker        if _is_float_or_complex_tensor(t):
76*da0073e9SAndroid Build Coastguard Worker            out.append(t.new_zeros((numel_input, t.numel()), **options))
77*da0073e9SAndroid Build Coastguard Worker    return tuple(out)
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Workerdef _iter_tensors(
81*da0073e9SAndroid Build Coastguard Worker    x: Union[torch.Tensor, Iterable[torch.Tensor]], only_requiring_grad: bool = False
82*da0073e9SAndroid Build Coastguard Worker) -> Iterable[torch.Tensor]:
83*da0073e9SAndroid Build Coastguard Worker    if is_tensor_like(x):
84*da0073e9SAndroid Build Coastguard Worker        # mypy doesn't narrow type of `x` to torch.Tensor
85*da0073e9SAndroid Build Coastguard Worker        if x.requires_grad or not only_requiring_grad:  # type: ignore[union-attr]
86*da0073e9SAndroid Build Coastguard Worker            yield x  # type: ignore[misc]
87*da0073e9SAndroid Build Coastguard Worker    elif isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
88*da0073e9SAndroid Build Coastguard Worker        for elem in x:
89*da0073e9SAndroid Build Coastguard Worker            yield from _iter_tensors(elem, only_requiring_grad)
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Workerdef _densify(x):
93*da0073e9SAndroid Build Coastguard Worker    # return a copy of sparse x with all unspecified elements
94*da0073e9SAndroid Build Coastguard Worker    # "replaced" with zero-valued elements
95*da0073e9SAndroid Build Coastguard Worker    if isinstance(x, (list, tuple)):
96*da0073e9SAndroid Build Coastguard Worker        return type(x)(map(_densify, x))
97*da0073e9SAndroid Build Coastguard Worker    elif not is_tensor_like(x) or x.layout in {torch.strided, torch._mkldnn}:  # type: ignore[attr-defined] # no attr _mkldnn
98*da0073e9SAndroid Build Coastguard Worker        return x
99*da0073e9SAndroid Build Coastguard Worker    elif x.layout is torch.sparse_coo:
100*da0073e9SAndroid Build Coastguard Worker        device = x.device
101*da0073e9SAndroid Build Coastguard Worker        indices_dtype = x._indices().dtype
102*da0073e9SAndroid Build Coastguard Worker        tmp = torch.ones(x.shape[: x.sparse_dim()], dtype=torch.int8, device=device)
103*da0073e9SAndroid Build Coastguard Worker        indices = tmp.nonzero().t().to(dtype=indices_dtype)
104*da0073e9SAndroid Build Coastguard Worker        values = torch.zeros(
105*da0073e9SAndroid Build Coastguard Worker            (tmp.numel(), *x.shape[x.sparse_dim() :]), dtype=x.dtype, device=device
106*da0073e9SAndroid Build Coastguard Worker        )
107*da0073e9SAndroid Build Coastguard Worker        x_coalesced = x.detach().coalesce()
108*da0073e9SAndroid Build Coastguard Worker        if x_coalesced.numel() > 0:
109*da0073e9SAndroid Build Coastguard Worker            stride = tmp.stride()
110*da0073e9SAndroid Build Coastguard Worker            flat_indices = (
111*da0073e9SAndroid Build Coastguard Worker                x_coalesced.indices()
112*da0073e9SAndroid Build Coastguard Worker                .mul(
113*da0073e9SAndroid Build Coastguard Worker                    torch.tensor(stride, dtype=indices_dtype, device=device).unsqueeze(
114*da0073e9SAndroid Build Coastguard Worker                        1
115*da0073e9SAndroid Build Coastguard Worker                    )
116*da0073e9SAndroid Build Coastguard Worker                )
117*da0073e9SAndroid Build Coastguard Worker                .sum(0)
118*da0073e9SAndroid Build Coastguard Worker            )
119*da0073e9SAndroid Build Coastguard Worker            values[flat_indices] = x_coalesced.values()
120*da0073e9SAndroid Build Coastguard Worker        return (
121*da0073e9SAndroid Build Coastguard Worker            torch.sparse_coo_tensor(indices, values, x.shape)
122*da0073e9SAndroid Build Coastguard Worker            ._coalesced_(True)
123*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(x.requires_grad)
124*da0073e9SAndroid Build Coastguard Worker        )
125*da0073e9SAndroid Build Coastguard Worker    elif _is_sparse_compressed_tensor(x):
126*da0073e9SAndroid Build Coastguard Worker        blocksize = (
127*da0073e9SAndroid Build Coastguard Worker            x.values().shape[1:3]
128*da0073e9SAndroid Build Coastguard Worker            if x.layout in {torch.sparse_bsr, torch.sparse_bsc}
129*da0073e9SAndroid Build Coastguard Worker            else None
130*da0073e9SAndroid Build Coastguard Worker        )
131*da0073e9SAndroid Build Coastguard Worker        compressed_indices = (
132*da0073e9SAndroid Build Coastguard Worker            x.crow_indices()
133*da0073e9SAndroid Build Coastguard Worker            if x.layout in {torch.sparse_csr, torch.sparse_bsr}
134*da0073e9SAndroid Build Coastguard Worker            else x.ccol_indices()
135*da0073e9SAndroid Build Coastguard Worker        )
136*da0073e9SAndroid Build Coastguard Worker        # We'll use intermediate sparse COO for simplicity
137*da0073e9SAndroid Build Coastguard Worker        r = _densify(x.detach().to_sparse(layout=torch.sparse_coo)).to_sparse(
138*da0073e9SAndroid Build Coastguard Worker            layout=x.layout, blocksize=blocksize
139*da0073e9SAndroid Build Coastguard Worker        )
140*da0073e9SAndroid Build Coastguard Worker        # Check that all elements are specified also after `to_sparse` op:
141*da0073e9SAndroid Build Coastguard Worker        dense_numel = r.values().numel() // max(1, r.values().shape[0])
142*da0073e9SAndroid Build Coastguard Worker        batch_numel = compressed_indices.numel() // compressed_indices.shape[-1]
143*da0073e9SAndroid Build Coastguard Worker        sparse_numel = r.numel() // max(1, dense_numel * batch_numel)
144*da0073e9SAndroid Build Coastguard Worker        if sparse_numel != r._nnz():
145*da0073e9SAndroid Build Coastguard Worker            raise AssertionError(
146*da0073e9SAndroid Build Coastguard Worker                f"{x.layout} densify failed: expected nnz={sparse_numel} but got {r._nnz()}"
147*da0073e9SAndroid Build Coastguard Worker            )
148*da0073e9SAndroid Build Coastguard Worker        return r.requires_grad_(x.requires_grad)
149*da0073e9SAndroid Build Coastguard Worker    elif _is_sparse_any_tensor(x):
150*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError(x.layout)
151*da0073e9SAndroid Build Coastguard Worker    return x
152*da0073e9SAndroid Build Coastguard Worker
153*da0073e9SAndroid Build Coastguard Worker
154*da0073e9SAndroid Build Coastguard Workerdef _iter_tensor(x_tensor):
155*da0073e9SAndroid Build Coastguard Worker    # (Only used for slow gradcheck) Returns a generator that yields the following
156*da0073e9SAndroid Build Coastguard Worker    # elements at each iteration:
157*da0073e9SAndroid Build Coastguard Worker    #  1) a tensor: the same tensor is returned across all iterations. The tensor
158*da0073e9SAndroid Build Coastguard Worker    #     is not the same as the original x_tensor as given as input - it is
159*da0073e9SAndroid Build Coastguard Worker    #     prepared so that it can be modified in-place. Depending on whether the
160*da0073e9SAndroid Build Coastguard Worker    #     input tensor is strided, sparse, or dense, the returned tensor may or may
161*da0073e9SAndroid Build Coastguard Worker    #     not share storage with x_tensor.
162*da0073e9SAndroid Build Coastguard Worker    #  2) a tuple of indices that can be used with advanced indexing (yielded in
163*da0073e9SAndroid Build Coastguard Worker    #     dictionary order)
164*da0073e9SAndroid Build Coastguard Worker    #  3) flattened index that will be used to index into the Jacobian tensor
165*da0073e9SAndroid Build Coastguard Worker    #
166*da0073e9SAndroid Build Coastguard Worker    # For a tensor t with size (2, 2), _iter_tensor yields:
167*da0073e9SAndroid Build Coastguard Worker    #     `x, (0, 0), 0`, `x, (0, 1), 1`, `x, (1, 0), 2`, `x, (1, 1), 3`
168*da0073e9SAndroid Build Coastguard Worker    #
169*da0073e9SAndroid Build Coastguard Worker    # where x is the t.data of the original tensor. Perturbing the entry of x
170*da0073e9SAndroid Build Coastguard Worker    # at index (1, 1) yields the 3rd column of the overall Jacobian matrix.
171*da0073e9SAndroid Build Coastguard Worker    if _is_sparse_any_tensor(x_tensor):
172*da0073e9SAndroid Build Coastguard Worker
173*da0073e9SAndroid Build Coastguard Worker        def get_stride(size):
174*da0073e9SAndroid Build Coastguard Worker            dim = len(size)
175*da0073e9SAndroid Build Coastguard Worker            tmp = 1
176*da0073e9SAndroid Build Coastguard Worker            stride = [0] * dim
177*da0073e9SAndroid Build Coastguard Worker            for i in reversed(range(dim)):
178*da0073e9SAndroid Build Coastguard Worker                stride[i] = tmp
179*da0073e9SAndroid Build Coastguard Worker                tmp *= size[i]
180*da0073e9SAndroid Build Coastguard Worker            return stride
181*da0073e9SAndroid Build Coastguard Worker
182*da0073e9SAndroid Build Coastguard Worker        x_nnz = x_tensor._nnz()
183*da0073e9SAndroid Build Coastguard Worker        x_size = list(x_tensor.size())
184*da0073e9SAndroid Build Coastguard Worker        if x_tensor.layout is torch.sparse_coo:
185*da0073e9SAndroid Build Coastguard Worker            x_indices = x_tensor._indices().t()
186*da0073e9SAndroid Build Coastguard Worker            x_values = x_tensor._values()
187*da0073e9SAndroid Build Coastguard Worker        elif x_tensor.layout is torch.sparse_csr:
188*da0073e9SAndroid Build Coastguard Worker            x_indices = torch._convert_indices_from_csr_to_coo(
189*da0073e9SAndroid Build Coastguard Worker                x_tensor.crow_indices(), x_tensor.col_indices()
190*da0073e9SAndroid Build Coastguard Worker            ).t()
191*da0073e9SAndroid Build Coastguard Worker            x_values = x_tensor.values()
192*da0073e9SAndroid Build Coastguard Worker        elif x_tensor.layout is torch.sparse_csc:
193*da0073e9SAndroid Build Coastguard Worker            x_indices = torch._convert_indices_from_csr_to_coo(
194*da0073e9SAndroid Build Coastguard Worker                x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True
195*da0073e9SAndroid Build Coastguard Worker            ).t()
196*da0073e9SAndroid Build Coastguard Worker            x_values = x_tensor.values()
197*da0073e9SAndroid Build Coastguard Worker        elif x_tensor.layout is torch.sparse_bsr:
198*da0073e9SAndroid Build Coastguard Worker            x_block_values = x_tensor.values()
199*da0073e9SAndroid Build Coastguard Worker            x_blocksize = x_block_values.size()[1:3]
200*da0073e9SAndroid Build Coastguard Worker            x_indices = (
201*da0073e9SAndroid Build Coastguard Worker                torch._convert_indices_from_csr_to_coo(
202*da0073e9SAndroid Build Coastguard Worker                    x_tensor.crow_indices(), x_tensor.col_indices()
203*da0073e9SAndroid Build Coastguard Worker                )
204*da0073e9SAndroid Build Coastguard Worker                .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1)
205*da0073e9SAndroid Build Coastguard Worker                .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1))
206*da0073e9SAndroid Build Coastguard Worker                .add_(
207*da0073e9SAndroid Build Coastguard Worker                    torch.stack(
208*da0073e9SAndroid Build Coastguard Worker                        torch.where(torch.ones(x_blocksize, device=x_tensor.device))
209*da0073e9SAndroid Build Coastguard Worker                    ).repeat(1, x_nnz)
210*da0073e9SAndroid Build Coastguard Worker                )
211*da0073e9SAndroid Build Coastguard Worker                .t()
212*da0073e9SAndroid Build Coastguard Worker            )
213*da0073e9SAndroid Build Coastguard Worker            x_values = x_block_values.flatten(0, 2)
214*da0073e9SAndroid Build Coastguard Worker            x_nnz = x_values.size(0)
215*da0073e9SAndroid Build Coastguard Worker        elif x_tensor.layout is torch.sparse_bsc:
216*da0073e9SAndroid Build Coastguard Worker            x_block_values = x_tensor.values()
217*da0073e9SAndroid Build Coastguard Worker            x_blocksize = x_block_values.size()[1:3]
218*da0073e9SAndroid Build Coastguard Worker            x_indices = (
219*da0073e9SAndroid Build Coastguard Worker                torch._convert_indices_from_csr_to_coo(
220*da0073e9SAndroid Build Coastguard Worker                    x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True
221*da0073e9SAndroid Build Coastguard Worker                )
222*da0073e9SAndroid Build Coastguard Worker                .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1)
223*da0073e9SAndroid Build Coastguard Worker                .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1))
224*da0073e9SAndroid Build Coastguard Worker                .add_(
225*da0073e9SAndroid Build Coastguard Worker                    torch.stack(
226*da0073e9SAndroid Build Coastguard Worker                        torch.where(torch.ones(x_blocksize, device=x_tensor.device))
227*da0073e9SAndroid Build Coastguard Worker                    ).repeat(1, x_nnz)
228*da0073e9SAndroid Build Coastguard Worker                )
229*da0073e9SAndroid Build Coastguard Worker                .t()
230*da0073e9SAndroid Build Coastguard Worker            )
231*da0073e9SAndroid Build Coastguard Worker            x_values = x_block_values.flatten(0, 2)
232*da0073e9SAndroid Build Coastguard Worker            x_nnz = x_values.size(0)
233*da0073e9SAndroid Build Coastguard Worker        else:
234*da0073e9SAndroid Build Coastguard Worker            raise NotImplementedError(f"_iter_tensor for {x_tensor.layout} input")
235*da0073e9SAndroid Build Coastguard Worker        x_stride = get_stride(x_size)
236*da0073e9SAndroid Build Coastguard Worker        # Use .data here to get around the version check
237*da0073e9SAndroid Build Coastguard Worker        x_values = x_values.data
238*da0073e9SAndroid Build Coastguard Worker        for i in range(x_nnz):
239*da0073e9SAndroid Build Coastguard Worker            x_value = x_values[i]
240*da0073e9SAndroid Build Coastguard Worker            for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
241*da0073e9SAndroid Build Coastguard Worker                indices = x_indices[i].tolist() + list(x_idx)
242*da0073e9SAndroid Build Coastguard Worker                d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
243*da0073e9SAndroid Build Coastguard Worker                yield x_value, x_idx, d_idx
244*da0073e9SAndroid Build Coastguard Worker    elif x_tensor.layout == torch._mkldnn:  # type: ignore[attr-defined]
245*da0073e9SAndroid Build Coastguard Worker        for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
246*da0073e9SAndroid Build Coastguard Worker            # this is really inefficient, but without indexing implemented, there's
247*da0073e9SAndroid Build Coastguard Worker            # not really a better way than converting back and forth
248*da0073e9SAndroid Build Coastguard Worker            x_tensor_dense = x_tensor.to_dense()
249*da0073e9SAndroid Build Coastguard Worker            yield x_tensor_dense, x_idx, d_idx
250*da0073e9SAndroid Build Coastguard Worker    else:
251*da0073e9SAndroid Build Coastguard Worker        # Use .data here to get around the version check
252*da0073e9SAndroid Build Coastguard Worker        x_tensor = x_tensor.data
253*da0073e9SAndroid Build Coastguard Worker        for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
254*da0073e9SAndroid Build Coastguard Worker            yield x_tensor, x_idx, d_idx
255*da0073e9SAndroid Build Coastguard Worker
256*da0073e9SAndroid Build Coastguard Worker
257*da0073e9SAndroid Build Coastguard Workerdef _get_numerical_jacobian(
258*da0073e9SAndroid Build Coastguard Worker    fn, inputs, outputs=None, target=None, eps=1e-3, is_forward_ad=False
259*da0073e9SAndroid Build Coastguard Worker) -> List[Tuple[torch.Tensor, ...]]:
260*da0073e9SAndroid Build Coastguard Worker    """Compute the numerical Jacobian of `fn(inputs)` with respect to `target`.
261*da0073e9SAndroid Build Coastguard Worker
262*da0073e9SAndroid Build Coastguard Worker    If not specified, targets are the input. Returns M * N Jacobians where N is the
263*da0073e9SAndroid Build Coastguard Worker    number of tensors in target that require grad and M is the number of non-integral
264*da0073e9SAndroid Build Coastguard Worker    outputs.
265*da0073e9SAndroid Build Coastguard Worker
266*da0073e9SAndroid Build Coastguard Worker    Args:
267*da0073e9SAndroid Build Coastguard Worker        fn: the function to compute the jacobian for
268*da0073e9SAndroid Build Coastguard Worker        inputs: inputs to `fn`
269*da0073e9SAndroid Build Coastguard Worker        outputs: provide precomputed outputs to avoid one extra invocation of fn
270*da0073e9SAndroid Build Coastguard Worker        target: the Tensors wrt whom Jacobians are calculated (default=`inputs`)
271*da0073e9SAndroid Build Coastguard Worker        eps: the magnitude of the perturbation during finite differencing
272*da0073e9SAndroid Build Coastguard Worker             (default=`1e-3`)
273*da0073e9SAndroid Build Coastguard Worker        is_forward_ad: if this numerical jacobian is computed to be checked wrt
274*da0073e9SAndroid Build Coastguard Worker                       forward AD gradients (this is used for error checking only)
275*da0073e9SAndroid Build Coastguard Worker
276*da0073e9SAndroid Build Coastguard Worker    Returns:
277*da0073e9SAndroid Build Coastguard Worker        A list of M N-tuples of tensors
278*da0073e9SAndroid Build Coastguard Worker
279*da0073e9SAndroid Build Coastguard Worker    Note that `target` may not even be part of `input` to `fn`, so please be
280*da0073e9SAndroid Build Coastguard Worker    **very careful** in this to not clone `target`.
281*da0073e9SAndroid Build Coastguard Worker    """
282*da0073e9SAndroid Build Coastguard Worker    jacobians: List[Tuple[torch.Tensor, ...]] = []
283*da0073e9SAndroid Build Coastguard Worker    if outputs is None:
284*da0073e9SAndroid Build Coastguard Worker        outputs = _as_tuple(fn(*_as_tuple(inputs)))
285*da0073e9SAndroid Build Coastguard Worker    if not is_forward_ad and any(o.is_complex() for o in outputs):
286*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
287*da0073e9SAndroid Build Coastguard Worker            "Expected output to be non-complex. get_numerical_jacobian no "
288*da0073e9SAndroid Build Coastguard Worker            "longer supports functions that return complex outputs."
289*da0073e9SAndroid Build Coastguard Worker        )
290*da0073e9SAndroid Build Coastguard Worker    if target is None:
291*da0073e9SAndroid Build Coastguard Worker        target = inputs
292*da0073e9SAndroid Build Coastguard Worker    inp_indices = [
293*da0073e9SAndroid Build Coastguard Worker        i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad
294*da0073e9SAndroid Build Coastguard Worker    ]
295*da0073e9SAndroid Build Coastguard Worker    for i, (inp, inp_idx) in enumerate(zip(_iter_tensors(target, True), inp_indices)):
296*da0073e9SAndroid Build Coastguard Worker        jacobians += [
297*da0073e9SAndroid Build Coastguard Worker            get_numerical_jacobian_wrt_specific_input(
298*da0073e9SAndroid Build Coastguard Worker                fn,
299*da0073e9SAndroid Build Coastguard Worker                inp_idx,
300*da0073e9SAndroid Build Coastguard Worker                inputs,
301*da0073e9SAndroid Build Coastguard Worker                outputs,
302*da0073e9SAndroid Build Coastguard Worker                eps,
303*da0073e9SAndroid Build Coastguard Worker                input=inp,
304*da0073e9SAndroid Build Coastguard Worker                is_forward_ad=is_forward_ad,
305*da0073e9SAndroid Build Coastguard Worker            )
306*da0073e9SAndroid Build Coastguard Worker        ]
307*da0073e9SAndroid Build Coastguard Worker    return jacobians
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Worker
310*da0073e9SAndroid Build Coastguard Worker@deprecated(
311*da0073e9SAndroid Build Coastguard Worker    "`get_numerical_jacobian` was part of PyTorch's private API and not "
312*da0073e9SAndroid Build Coastguard Worker    "meant to be exposed. We are deprecating it and it will be removed "
313*da0073e9SAndroid Build Coastguard Worker    "in a future version of PyTorch. If you have a specific use for "
314*da0073e9SAndroid Build Coastguard Worker    "this or feature request for this to be a stable API, please file "
315*da0073e9SAndroid Build Coastguard Worker    "us an issue at https://github.com/pytorch/pytorch/issues/new",
316*da0073e9SAndroid Build Coastguard Worker    category=FutureWarning,
317*da0073e9SAndroid Build Coastguard Worker)
318*da0073e9SAndroid Build Coastguard Workerdef get_numerical_jacobian(fn, inputs, target=None, eps=1e-3, grad_out=1.0):
319*da0073e9SAndroid Build Coastguard Worker    """Compute the numerical Jacobian for a given fn and its inputs.
320*da0073e9SAndroid Build Coastguard Worker
321*da0073e9SAndroid Build Coastguard Worker    This is a Deprecated API.
322*da0073e9SAndroid Build Coastguard Worker
323*da0073e9SAndroid Build Coastguard Worker    Args:
324*da0073e9SAndroid Build Coastguard Worker        fn: the function to compute the Jacobian for (must take inputs as a tuple)
325*da0073e9SAndroid Build Coastguard Worker        input: input to `fn`
326*da0073e9SAndroid Build Coastguard Worker        target: the Tensors wrt whom Jacobians are calculated (default=`input`)
327*da0073e9SAndroid Build Coastguard Worker        eps: the magnitude of the perturbation during finite differencing
328*da0073e9SAndroid Build Coastguard Worker             (default=`1e-3`)
329*da0073e9SAndroid Build Coastguard Worker
330*da0073e9SAndroid Build Coastguard Worker    Returns:
331*da0073e9SAndroid Build Coastguard Worker        A list of Jacobians of `fn` (restricted to its first output) with respect to
332*da0073e9SAndroid Build Coastguard Worker        each input or target, if provided.
333*da0073e9SAndroid Build Coastguard Worker
334*da0073e9SAndroid Build Coastguard Worker    Note that `target` may not even be part of `input` to `fn`, so please be
335*da0073e9SAndroid Build Coastguard Worker    **very careful** in this to not clone `target`.
336*da0073e9SAndroid Build Coastguard Worker    """
337*da0073e9SAndroid Build Coastguard Worker    if (
338*da0073e9SAndroid Build Coastguard Worker        grad_out != 1.0
339*da0073e9SAndroid Build Coastguard Worker    ):  # grad_out param is only kept for backward compatibility reasons
340*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
341*da0073e9SAndroid Build Coastguard Worker            "Expected grad_out to be 1.0. get_numerical_jacobian no longer "
342*da0073e9SAndroid Build Coastguard Worker            "supports values of grad_out != 1.0."
343*da0073e9SAndroid Build Coastguard Worker        )
344*da0073e9SAndroid Build Coastguard Worker
345*da0073e9SAndroid Build Coastguard Worker    def fn_pack_inps(*inps):
346*da0073e9SAndroid Build Coastguard Worker        return fn(inps)
347*da0073e9SAndroid Build Coastguard Worker
348*da0073e9SAndroid Build Coastguard Worker    jacobians = _get_numerical_jacobian(fn_pack_inps, inputs, None, target, eps)
349*da0073e9SAndroid Build Coastguard Worker
350*da0073e9SAndroid Build Coastguard Worker    return tuple(jacobian_for_each_output[0] for jacobian_for_each_output in jacobians)
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker
353*da0073e9SAndroid Build Coastguard Workerdef _compute_numerical_gradient(fn, entry, v, norm_v, nbhd_checks_fn):
354*da0073e9SAndroid Build Coastguard Worker    # Computes numerical directional derivative as finite difference
355*da0073e9SAndroid Build Coastguard Worker    # of function `fn` at input `entry`, perturbed by vector `v`.
356*da0073e9SAndroid Build Coastguard Worker    if _is_sparse_compressed_tensor(entry):
357*da0073e9SAndroid Build Coastguard Worker        # sparse compressed tensors don't implement sub/add/copy_
358*da0073e9SAndroid Build Coastguard Worker        # yet. However, in non-masked semantics context entry and v
359*da0073e9SAndroid Build Coastguard Worker        # have the same sparse indices ...
360*da0073e9SAndroid Build Coastguard Worker        assert entry.layout == v.layout, (entry.layout, v.layout)
361*da0073e9SAndroid Build Coastguard Worker        assert entry._nnz() == v._nnz(), (entry._nnz(), v._nnz(), entry.shape)
362*da0073e9SAndroid Build Coastguard Worker        # ... the finite differencing can be performed on values only:
363*da0073e9SAndroid Build Coastguard Worker        entry = entry.values()
364*da0073e9SAndroid Build Coastguard Worker        v = v.values()
365*da0073e9SAndroid Build Coastguard Worker        # we'll detach to avoid backward computations that sparse
366*da0073e9SAndroid Build Coastguard Worker        # tensors have limited support for.
367*da0073e9SAndroid Build Coastguard Worker        entry = entry.detach()
368*da0073e9SAndroid Build Coastguard Worker
369*da0073e9SAndroid Build Coastguard Worker    orig = entry.clone()
370*da0073e9SAndroid Build Coastguard Worker    entry.copy_(orig - v)
371*da0073e9SAndroid Build Coastguard Worker    outa = fn()
372*da0073e9SAndroid Build Coastguard Worker    entry.copy_(orig + v)
373*da0073e9SAndroid Build Coastguard Worker    outb = fn()
374*da0073e9SAndroid Build Coastguard Worker    entry.copy_(orig)
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker    def compute(a, b):
377*da0073e9SAndroid Build Coastguard Worker        nbhd_checks_fn(a, b)
378*da0073e9SAndroid Build Coastguard Worker        ret = (b - a) / (2 * norm_v)  # use central difference approx
379*da0073e9SAndroid Build Coastguard Worker        return ret.detach().reshape(-1)
380*da0073e9SAndroid Build Coastguard Worker
381*da0073e9SAndroid Build Coastguard Worker    return tuple(compute(a, b) for (a, b) in zip(outa, outb))
382*da0073e9SAndroid Build Coastguard Worker
383*da0073e9SAndroid Build Coastguard Worker
384*da0073e9SAndroid Build Coastguard Workerdef _compute_numerical_jvps_wrt_specific_input(
385*da0073e9SAndroid Build Coastguard Worker    jvp_fn, delta, input_is_complex, is_forward_ad=False
386*da0073e9SAndroid Build Coastguard Worker) -> List[torch.Tensor]:
387*da0073e9SAndroid Build Coastguard Worker    # Computing the jacobian only works for real delta
388*da0073e9SAndroid Build Coastguard Worker    # For details on the algorithm used here, refer:
389*da0073e9SAndroid Build Coastguard Worker    # Section 3.5.3 https://arxiv.org/pdf/1701.00392.pdf
390*da0073e9SAndroid Build Coastguard Worker    # s = fn(z) where z = x for real valued input
391*da0073e9SAndroid Build Coastguard Worker    # and z = x + yj for complex valued input
392*da0073e9SAndroid Build Coastguard Worker    jvps: List[torch.Tensor] = []
393*da0073e9SAndroid Build Coastguard Worker    ds_dx_tup = jvp_fn(delta[0] if isinstance(delta, tuple) else delta)
394*da0073e9SAndroid Build Coastguard Worker
395*da0073e9SAndroid Build Coastguard Worker    if input_is_complex:  # C -> R
396*da0073e9SAndroid Build Coastguard Worker        ds_dy_tup = (
397*da0073e9SAndroid Build Coastguard Worker            jvp_fn(delta[1] * 1j) if isinstance(delta, tuple) else jvp_fn(delta * 1j)
398*da0073e9SAndroid Build Coastguard Worker        )
399*da0073e9SAndroid Build Coastguard Worker        for ds_dx, ds_dy in zip(ds_dx_tup, ds_dy_tup):
400*da0073e9SAndroid Build Coastguard Worker            assert not ds_dx.is_complex()
401*da0073e9SAndroid Build Coastguard Worker            # conjugate wirtinger derivative
402*da0073e9SAndroid Build Coastguard Worker            conj_w_d = ds_dx + ds_dy * 1j
403*da0073e9SAndroid Build Coastguard Worker            jvps.append(conj_w_d)
404*da0073e9SAndroid Build Coastguard Worker    else:
405*da0073e9SAndroid Build Coastguard Worker        for ds_dx in ds_dx_tup:  # R -> R or (R -> C for the forward AD case)
406*da0073e9SAndroid Build Coastguard Worker            assert is_forward_ad or not ds_dx.is_complex()
407*da0073e9SAndroid Build Coastguard Worker            jvps.append(ds_dx)
408*da0073e9SAndroid Build Coastguard Worker    return jvps
409*da0073e9SAndroid Build Coastguard Worker
410*da0073e9SAndroid Build Coastguard Worker
411*da0073e9SAndroid Build Coastguard Workerdef _combine_jacobian_cols(
412*da0073e9SAndroid Build Coastguard Worker    jacobians_cols: Dict[int, List[torch.Tensor]], outputs, input, numel
413*da0073e9SAndroid Build Coastguard Worker) -> Tuple[torch.Tensor, ...]:
414*da0073e9SAndroid Build Coastguard Worker    # jacobian_cols maps column_idx -> output_idx -> single column of jacobian Tensor
415*da0073e9SAndroid Build Coastguard Worker    # we return a list that maps output_idx -> full jacobian Tensor
416*da0073e9SAndroid Build Coastguard Worker    jacobians = _allocate_jacobians_with_outputs(
417*da0073e9SAndroid Build Coastguard Worker        outputs, numel, dtype=input.dtype if input.dtype.is_complex else None
418*da0073e9SAndroid Build Coastguard Worker    )
419*da0073e9SAndroid Build Coastguard Worker    for i, jacobian in enumerate(jacobians):
420*da0073e9SAndroid Build Coastguard Worker        for k, v in jacobians_cols.items():
421*da0073e9SAndroid Build Coastguard Worker            jacobian[k] = v[i]
422*da0073e9SAndroid Build Coastguard Worker    return jacobians
423*da0073e9SAndroid Build Coastguard Worker
424*da0073e9SAndroid Build Coastguard Worker
425*da0073e9SAndroid Build Coastguard Workerdef _prepare_input(
426*da0073e9SAndroid Build Coastguard Worker    input: torch.Tensor, maybe_perturbed_input: Optional[torch.Tensor], fast_mode=False
427*da0073e9SAndroid Build Coastguard Worker) -> torch.Tensor:
428*da0073e9SAndroid Build Coastguard Worker    # Prepares the inputs to be passed into the function while including the new
429*da0073e9SAndroid Build Coastguard Worker    # modified input.
430*da0073e9SAndroid Build Coastguard Worker    if input.layout == torch._mkldnn:  # type: ignore[attr-defined] # no attr _mkldnn
431*da0073e9SAndroid Build Coastguard Worker        # Convert back to mkldnn
432*da0073e9SAndroid Build Coastguard Worker        if maybe_perturbed_input is not None:
433*da0073e9SAndroid Build Coastguard Worker            return maybe_perturbed_input.to_mkldnn()
434*da0073e9SAndroid Build Coastguard Worker        else:
435*da0073e9SAndroid Build Coastguard Worker            return input
436*da0073e9SAndroid Build Coastguard Worker    elif _is_sparse_any_tensor(input):
437*da0073e9SAndroid Build Coastguard Worker        if fast_mode and maybe_perturbed_input is not None:
438*da0073e9SAndroid Build Coastguard Worker            # entry is already a "cloned" version of the original tensor
439*da0073e9SAndroid Build Coastguard Worker            # thus changes to entry are not reflected in the input
440*da0073e9SAndroid Build Coastguard Worker            return maybe_perturbed_input
441*da0073e9SAndroid Build Coastguard Worker        else:
442*da0073e9SAndroid Build Coastguard Worker            return input
443*da0073e9SAndroid Build Coastguard Worker    else:
444*da0073e9SAndroid Build Coastguard Worker        # We cannot use entry (input.data) if we want gradgrad to work because
445*da0073e9SAndroid Build Coastguard Worker        # fn (in the gradgrad case) needs to compute grad wrt input
446*da0073e9SAndroid Build Coastguard Worker        return input
447*da0073e9SAndroid Build Coastguard Worker
448*da0073e9SAndroid Build Coastguard Worker
449*da0073e9SAndroid Build Coastguard Workerdef _check_outputs_same_dtype_and_shape(output1, output2, eps, idx=None) -> None:
450*da0073e9SAndroid Build Coastguard Worker    # Check that the returned outputs don't have different dtype or shape when you
451*da0073e9SAndroid Build Coastguard Worker    # perturb the input
452*da0073e9SAndroid Build Coastguard Worker    on_index = "on index {idx} " if idx is not None else ""
453*da0073e9SAndroid Build Coastguard Worker    assert output1.shape == output2.shape, (
454*da0073e9SAndroid Build Coastguard Worker        f"Expected `func` to return outputs with the same shape"
455*da0073e9SAndroid Build Coastguard Worker        f" when inputs are perturbed {on_index}by {eps}, but got:"
456*da0073e9SAndroid Build Coastguard Worker        f" shapes {output1.shape} and {output2.shape}."
457*da0073e9SAndroid Build Coastguard Worker    )
458*da0073e9SAndroid Build Coastguard Worker    assert output1.dtype == output2.dtype, (
459*da0073e9SAndroid Build Coastguard Worker        f"Expected `func` to return outputs with the same dtype"
460*da0073e9SAndroid Build Coastguard Worker        f" when inputs are perturbed {on_index}by {eps}, but got:"
461*da0073e9SAndroid Build Coastguard Worker        f" dtypes {output1.dtype} and {output2.dtype}."
462*da0073e9SAndroid Build Coastguard Worker    )
463*da0073e9SAndroid Build Coastguard Worker
464*da0073e9SAndroid Build Coastguard Worker
465*da0073e9SAndroid Build Coastguard Workerdef get_numerical_jacobian_wrt_specific_input(
466*da0073e9SAndroid Build Coastguard Worker    fn, input_idx, inputs, outputs, eps, input=None, is_forward_ad=False
467*da0073e9SAndroid Build Coastguard Worker) -> Tuple[torch.Tensor, ...]:
468*da0073e9SAndroid Build Coastguard Worker    # Computes the numerical jacobians wrt to a single input. Returns N jacobian
469*da0073e9SAndroid Build Coastguard Worker    # tensors, where N is the number of outputs. We use a dictionary for
470*da0073e9SAndroid Build Coastguard Worker    # jacobian_cols because indices aren't necessarily consecutive for sparse inputs
471*da0073e9SAndroid Build Coastguard Worker    # When we perturb only a single element of the input tensor at a time, the jvp
472*da0073e9SAndroid Build Coastguard Worker    # is equivalent to a single col of the Jacobian matrix of fn.
473*da0073e9SAndroid Build Coastguard Worker    jacobian_cols: Dict[int, List[torch.Tensor]] = {}
474*da0073e9SAndroid Build Coastguard Worker    input = inputs[input_idx] if input is None else input
475*da0073e9SAndroid Build Coastguard Worker    assert input.requires_grad
476*da0073e9SAndroid Build Coastguard Worker    for x, idx, d_idx in _iter_tensor(input):
477*da0073e9SAndroid Build Coastguard Worker        wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, x)
478*da0073e9SAndroid Build Coastguard Worker        input_to_perturb = x[idx]
479*da0073e9SAndroid Build Coastguard Worker        nbhd_checks_fn = functools.partial(
480*da0073e9SAndroid Build Coastguard Worker            _check_outputs_same_dtype_and_shape, idx=idx, eps=eps
481*da0073e9SAndroid Build Coastguard Worker        )
482*da0073e9SAndroid Build Coastguard Worker        jvp_fn = _get_numerical_jvp_fn(
483*da0073e9SAndroid Build Coastguard Worker            wrapped_fn, input_to_perturb, eps, nbhd_checks_fn
484*da0073e9SAndroid Build Coastguard Worker        )
485*da0073e9SAndroid Build Coastguard Worker        jacobian_cols[d_idx] = _compute_numerical_jvps_wrt_specific_input(
486*da0073e9SAndroid Build Coastguard Worker            jvp_fn, eps, x.is_complex(), is_forward_ad
487*da0073e9SAndroid Build Coastguard Worker        )
488*da0073e9SAndroid Build Coastguard Worker    return _combine_jacobian_cols(jacobian_cols, outputs, input, input.numel())
489*da0073e9SAndroid Build Coastguard Worker
490*da0073e9SAndroid Build Coastguard Worker
491*da0073e9SAndroid Build Coastguard Workerdef _get_analytical_jacobian_forward_ad(
492*da0073e9SAndroid Build Coastguard Worker    fn, inputs, outputs, *, check_grad_dtypes=False, all_u=None
493*da0073e9SAndroid Build Coastguard Worker) -> Tuple[Tuple[torch.Tensor, ...], ...]:
494*da0073e9SAndroid Build Coastguard Worker    """Compute the analytical Jacobian using forward mode AD of `fn(inputs)` using forward mode AD with respect to `target`.
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Worker    Return N * M Jacobians where N is the number of tensors in target that require grad and
497*da0073e9SAndroid Build Coastguard Worker    M is the number of non-integral outputs.
498*da0073e9SAndroid Build Coastguard Worker    Contrary to other functions here, this function requires "inputs" to actually be used by the function.
499*da0073e9SAndroid Build Coastguard Worker    The computed value is expected to be wrong if the function captures the inputs by side effect instead of
500*da0073e9SAndroid Build Coastguard Worker    using the passed ones (many torch.nn tests do this).
501*da0073e9SAndroid Build Coastguard Worker
502*da0073e9SAndroid Build Coastguard Worker    Args:
503*da0073e9SAndroid Build Coastguard Worker        fn: the function to compute the jacobian for
504*da0073e9SAndroid Build Coastguard Worker        inputs: inputs to `fn`
505*da0073e9SAndroid Build Coastguard Worker        outputs: provide precomputed outputs to avoid one extra invocation of fn
506*da0073e9SAndroid Build Coastguard Worker        check_grad_dtypes: if True, will check that the gradient dtype are valid
507*da0073e9SAndroid Build Coastguard Worker        all_u (optional): if provided, the Jacobian will be right multiplied with this vector
508*da0073e9SAndroid Build Coastguard Worker
509*da0073e9SAndroid Build Coastguard Worker    Returns:
510*da0073e9SAndroid Build Coastguard Worker        A tuple of M N-tuples of tensors
511*da0073e9SAndroid Build Coastguard Worker    """
512*da0073e9SAndroid Build Coastguard Worker    # To avoid early import issues
513*da0073e9SAndroid Build Coastguard Worker    fwAD = torch.autograd.forward_ad
514*da0073e9SAndroid Build Coastguard Worker
515*da0073e9SAndroid Build Coastguard Worker    tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad)
516*da0073e9SAndroid Build Coastguard Worker
517*da0073e9SAndroid Build Coastguard Worker    if any(i.is_complex() for i in tensor_inputs):
518*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
519*da0073e9SAndroid Build Coastguard Worker            "Expected inputs to be non-complex for _get_analytical_jacobian_forward_ad."
520*da0073e9SAndroid Build Coastguard Worker        )
521*da0073e9SAndroid Build Coastguard Worker
522*da0073e9SAndroid Build Coastguard Worker    if all_u:
523*da0073e9SAndroid Build Coastguard Worker        jacobians = tuple(
524*da0073e9SAndroid Build Coastguard Worker            _allocate_jacobians_with_outputs(outputs, 1) for i in tensor_inputs
525*da0073e9SAndroid Build Coastguard Worker        )
526*da0073e9SAndroid Build Coastguard Worker    else:
527*da0073e9SAndroid Build Coastguard Worker        jacobians = tuple(
528*da0073e9SAndroid Build Coastguard Worker            _allocate_jacobians_with_outputs(outputs, i.numel()) for i in tensor_inputs
529*da0073e9SAndroid Build Coastguard Worker        )
530*da0073e9SAndroid Build Coastguard Worker
531*da0073e9SAndroid Build Coastguard Worker    with fwAD.dual_level():
532*da0073e9SAndroid Build Coastguard Worker        fw_grads = []
533*da0073e9SAndroid Build Coastguard Worker        dual_inputs = []
534*da0073e9SAndroid Build Coastguard Worker        for i, inp in enumerate(inputs):
535*da0073e9SAndroid Build Coastguard Worker            if is_tensor_like(inp) and inp.requires_grad:
536*da0073e9SAndroid Build Coastguard Worker                if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
537*da0073e9SAndroid Build Coastguard Worker                    raise ValueError(
538*da0073e9SAndroid Build Coastguard Worker                        "MKLDNN inputs are not support for forward AD gradcheck."
539*da0073e9SAndroid Build Coastguard Worker                    )
540*da0073e9SAndroid Build Coastguard Worker
541*da0073e9SAndroid Build Coastguard Worker                inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
542*da0073e9SAndroid Build Coastguard Worker                # If inp is a differentiable view, the dual might not be the tangent given to
543*da0073e9SAndroid Build Coastguard Worker                # make_dual, so read it explicitly from the dual tensor
544*da0073e9SAndroid Build Coastguard Worker                fw_grads.append(fwAD.unpack_dual(inp)[1])
545*da0073e9SAndroid Build Coastguard Worker            dual_inputs.append(inp)
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker        if all_u:
548*da0073e9SAndroid Build Coastguard Worker            # Do the full reduction in one pass
549*da0073e9SAndroid Build Coastguard Worker            # To be consistent with numerical evaluation, we actually compute one reduction per input
550*da0073e9SAndroid Build Coastguard Worker            for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
551*da0073e9SAndroid Build Coastguard Worker                fw_grad.copy_(u.view_as(fw_grad))
552*da0073e9SAndroid Build Coastguard Worker                raw_outputs = _as_tuple(fn(*dual_inputs))
553*da0073e9SAndroid Build Coastguard Worker                dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs)
554*da0073e9SAndroid Build Coastguard Worker                for index_o, d_o in enumerate(dual_outputs):
555*da0073e9SAndroid Build Coastguard Worker                    val, res = fwAD.unpack_dual(d_o)
556*da0073e9SAndroid Build Coastguard Worker                    if (
557*da0073e9SAndroid Build Coastguard Worker                        check_grad_dtypes
558*da0073e9SAndroid Build Coastguard Worker                        and res is not None
559*da0073e9SAndroid Build Coastguard Worker                        and val.is_complex() != res.is_complex()
560*da0073e9SAndroid Build Coastguard Worker                    ):
561*da0073e9SAndroid Build Coastguard Worker                        raise GradcheckError("Forward AD gradient has dtype mismatch.")
562*da0073e9SAndroid Build Coastguard Worker
563*da0073e9SAndroid Build Coastguard Worker                    # Remove extra dimension of size 1 corresponding to the reduced input
564*da0073e9SAndroid Build Coastguard Worker                    jacobians[i][index_o].squeeze_(0)
565*da0073e9SAndroid Build Coastguard Worker                    if res is None:
566*da0073e9SAndroid Build Coastguard Worker                        jacobians[i][index_o].zero_()
567*da0073e9SAndroid Build Coastguard Worker                    else:
568*da0073e9SAndroid Build Coastguard Worker                        jacobians[i][index_o].copy_(res.reshape(-1))
569*da0073e9SAndroid Build Coastguard Worker                fw_grad.zero_()
570*da0073e9SAndroid Build Coastguard Worker        else:
571*da0073e9SAndroid Build Coastguard Worker            # Reconstruct the full Jacobian column by column
572*da0073e9SAndroid Build Coastguard Worker            for i, fw_grad in enumerate(fw_grads):
573*da0073e9SAndroid Build Coastguard Worker                for lin_idx, grad_idx in enumerate(
574*da0073e9SAndroid Build Coastguard Worker                    product(*[range(m) for m in fw_grad.size()])
575*da0073e9SAndroid Build Coastguard Worker                ):
576*da0073e9SAndroid Build Coastguard Worker                    fw_grad[grad_idx] = 1.0
577*da0073e9SAndroid Build Coastguard Worker                    raw_outputs = _as_tuple(fn(*dual_inputs))
578*da0073e9SAndroid Build Coastguard Worker                    dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs)
579*da0073e9SAndroid Build Coastguard Worker                    for index_o, d_o in enumerate(dual_outputs):
580*da0073e9SAndroid Build Coastguard Worker                        val, res = fwAD.unpack_dual(d_o)
581*da0073e9SAndroid Build Coastguard Worker                        if (
582*da0073e9SAndroid Build Coastguard Worker                            check_grad_dtypes
583*da0073e9SAndroid Build Coastguard Worker                            and res is not None
584*da0073e9SAndroid Build Coastguard Worker                            and val.is_complex() != res.is_complex()
585*da0073e9SAndroid Build Coastguard Worker                        ):
586*da0073e9SAndroid Build Coastguard Worker                            raise GradcheckError(
587*da0073e9SAndroid Build Coastguard Worker                                "Forward AD gradient has dtype mismatch."
588*da0073e9SAndroid Build Coastguard Worker                            )
589*da0073e9SAndroid Build Coastguard Worker
590*da0073e9SAndroid Build Coastguard Worker                        if res is None:
591*da0073e9SAndroid Build Coastguard Worker                            jacobians[i][index_o][lin_idx].zero_()
592*da0073e9SAndroid Build Coastguard Worker                        else:
593*da0073e9SAndroid Build Coastguard Worker                            jacobians[i][index_o][lin_idx].copy_(res.reshape(-1))
594*da0073e9SAndroid Build Coastguard Worker                    fw_grad[grad_idx] = 0.0
595*da0073e9SAndroid Build Coastguard Worker
596*da0073e9SAndroid Build Coastguard Worker    return jacobians
597*da0073e9SAndroid Build Coastguard Worker
598*da0073e9SAndroid Build Coastguard Worker
599*da0073e9SAndroid Build Coastguard Workerdef _get_input_to_perturb(input):
600*da0073e9SAndroid Build Coastguard Worker    # Prepare the input so that it can be modified in-place and do certain
601*da0073e9SAndroid Build Coastguard Worker    # operations that require the tensor to have strides. If fast_mode=False,
602*da0073e9SAndroid Build Coastguard Worker    # _iter_tensor would handle the below cases:
603*da0073e9SAndroid Build Coastguard Worker    if input.layout == torch._mkldnn:  # type: ignore[attr-defined] # no attr _mkldnn
604*da0073e9SAndroid Build Coastguard Worker        # Convert to dense so we can perform operations that require strided tensors
605*da0073e9SAndroid Build Coastguard Worker        input_to_perturb = input.to_dense()
606*da0073e9SAndroid Build Coastguard Worker    elif _is_sparse_any_tensor(input):
607*da0073e9SAndroid Build Coastguard Worker        # Clone because input may require grad, and copy_ calls resize_,
608*da0073e9SAndroid Build Coastguard Worker        # which is not allowed for .data
609*da0073e9SAndroid Build Coastguard Worker        input_to_perturb = input.clone()
610*da0073e9SAndroid Build Coastguard Worker    else:
611*da0073e9SAndroid Build Coastguard Worker        input_to_perturb = input.data
612*da0073e9SAndroid Build Coastguard Worker    return input_to_perturb
613*da0073e9SAndroid Build Coastguard Worker
614*da0073e9SAndroid Build Coastguard Worker
615*da0073e9SAndroid Build Coastguard Workerdef _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, fast_mode=False):
616*da0073e9SAndroid Build Coastguard Worker    # Wraps `fn` so that its inputs are already supplied
617*da0073e9SAndroid Build Coastguard Worker    def wrapped_fn():
618*da0073e9SAndroid Build Coastguard Worker        inp = tuple(
619*da0073e9SAndroid Build Coastguard Worker            _prepare_input(a, input_to_perturb if i == input_idx else None, fast_mode)
620*da0073e9SAndroid Build Coastguard Worker            if is_tensor_like(a)
621*da0073e9SAndroid Build Coastguard Worker            else a
622*da0073e9SAndroid Build Coastguard Worker            for i, a in enumerate(_as_tuple(inputs))
623*da0073e9SAndroid Build Coastguard Worker        )
624*da0073e9SAndroid Build Coastguard Worker        return tuple(a.clone() for a in _as_tuple(fn(*inp)))
625*da0073e9SAndroid Build Coastguard Worker
626*da0073e9SAndroid Build Coastguard Worker    return wrapped_fn
627*da0073e9SAndroid Build Coastguard Worker
628*da0073e9SAndroid Build Coastguard Worker
629*da0073e9SAndroid Build Coastguard Workerdef _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn):
630*da0073e9SAndroid Build Coastguard Worker    # Wraps jvp_fn so that certain arguments are already supplied
631*da0073e9SAndroid Build Coastguard Worker    def jvp_fn(delta):
632*da0073e9SAndroid Build Coastguard Worker        return _compute_numerical_gradient(
633*da0073e9SAndroid Build Coastguard Worker            wrapped_fn, input_to_perturb, delta, eps, nbhd_checks_fn
634*da0073e9SAndroid Build Coastguard Worker        )
635*da0073e9SAndroid Build Coastguard Worker
636*da0073e9SAndroid Build Coastguard Worker    return jvp_fn
637*da0073e9SAndroid Build Coastguard Worker
638*da0073e9SAndroid Build Coastguard Worker
639*da0073e9SAndroid Build Coastguard Workerdef _reshape_tensor_or_tuple(u, shape):
640*da0073e9SAndroid Build Coastguard Worker    # We don't need to reshape when input corresponding to u is sparse
641*da0073e9SAndroid Build Coastguard Worker    if isinstance(u, tuple):
642*da0073e9SAndroid Build Coastguard Worker        if not _is_sparse_any_tensor(u[0]):
643*da0073e9SAndroid Build Coastguard Worker            return (u[0].reshape(shape), u[1].reshape(shape))
644*da0073e9SAndroid Build Coastguard Worker    else:
645*da0073e9SAndroid Build Coastguard Worker        if not _is_sparse_any_tensor(u):
646*da0073e9SAndroid Build Coastguard Worker            return u.reshape(shape)
647*da0073e9SAndroid Build Coastguard Worker    return u
648*da0073e9SAndroid Build Coastguard Worker
649*da0073e9SAndroid Build Coastguard Worker
650*da0073e9SAndroid Build Coastguard Workerdef _mul_tensor_or_tuple(u, k):
651*da0073e9SAndroid Build Coastguard Worker    if isinstance(u, tuple):
652*da0073e9SAndroid Build Coastguard Worker        return (k * u[0], k * u[1])
653*da0073e9SAndroid Build Coastguard Worker    else:
654*da0073e9SAndroid Build Coastguard Worker        return k * u
655*da0073e9SAndroid Build Coastguard Worker
656*da0073e9SAndroid Build Coastguard Worker
657*da0073e9SAndroid Build Coastguard Workerdef _get_numerical_jvp_wrt_specific_input(
658*da0073e9SAndroid Build Coastguard Worker    fn, input_idx, inputs, u, eps, is_forward_ad=False
659*da0073e9SAndroid Build Coastguard Worker) -> List[torch.Tensor]:
660*da0073e9SAndroid Build Coastguard Worker    input = inputs[input_idx]
661*da0073e9SAndroid Build Coastguard Worker    input_to_perturb = _get_input_to_perturb(input)
662*da0073e9SAndroid Build Coastguard Worker    wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, True)
663*da0073e9SAndroid Build Coastguard Worker    nbhd_checks_fn = functools.partial(_check_outputs_same_dtype_and_shape, eps=eps)
664*da0073e9SAndroid Build Coastguard Worker    jvp_fn = _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn)
665*da0073e9SAndroid Build Coastguard Worker    u = _reshape_tensor_or_tuple(u, input_to_perturb.shape)
666*da0073e9SAndroid Build Coastguard Worker    u = _mul_tensor_or_tuple(u, eps)
667*da0073e9SAndroid Build Coastguard Worker    return _compute_numerical_jvps_wrt_specific_input(
668*da0073e9SAndroid Build Coastguard Worker        jvp_fn, u, input.is_complex(), is_forward_ad
669*da0073e9SAndroid Build Coastguard Worker    )
670*da0073e9SAndroid Build Coastguard Worker
671*da0073e9SAndroid Build Coastguard Worker
672*da0073e9SAndroid Build Coastguard Workerdef _get_numerical_vJu(
673*da0073e9SAndroid Build Coastguard Worker    fn, inputs, inp_indices, func_out, all_u, all_v, eps, is_forward_ad
674*da0073e9SAndroid Build Coastguard Worker):
675*da0073e9SAndroid Build Coastguard Worker    # Note that all_v can also be None, in that case, this function only computes Ju.
676*da0073e9SAndroid Build Coastguard Worker    reduced_jacobians: List[List[torch.Tensor]] = []
677*da0073e9SAndroid Build Coastguard Worker    for i, (inp_idx, u) in enumerate(zip(inp_indices, all_u)):
678*da0073e9SAndroid Build Coastguard Worker        all_Ju = _get_numerical_jvp_wrt_specific_input(
679*da0073e9SAndroid Build Coastguard Worker            fn, inp_idx, inputs, u, eps, is_forward_ad
680*da0073e9SAndroid Build Coastguard Worker        )
681*da0073e9SAndroid Build Coastguard Worker        # Filter out the Ju for non floating point outputs
682*da0073e9SAndroid Build Coastguard Worker        filtered_Ju = []
683*da0073e9SAndroid Build Coastguard Worker        func_out = _as_tuple(func_out)
684*da0073e9SAndroid Build Coastguard Worker        assert len(all_Ju) == len(func_out)
685*da0073e9SAndroid Build Coastguard Worker        for Ju, output in zip(all_Ju, func_out):
686*da0073e9SAndroid Build Coastguard Worker            if _is_float_or_complex_tensor(output):
687*da0073e9SAndroid Build Coastguard Worker                filtered_Ju.append(Ju)
688*da0073e9SAndroid Build Coastguard Worker            else:
689*da0073e9SAndroid Build Coastguard Worker                # TODO: handle the other Ju
690*da0073e9SAndroid Build Coastguard Worker                pass
691*da0073e9SAndroid Build Coastguard Worker        if all_v is not None:
692*da0073e9SAndroid Build Coastguard Worker            jacobian_scalars: List[torch.Tensor] = []
693*da0073e9SAndroid Build Coastguard Worker            for v, Ju in zip(all_v, filtered_Ju):
694*da0073e9SAndroid Build Coastguard Worker                jacobian_scalars.append(_dot_with_type_promotion(v, Ju))
695*da0073e9SAndroid Build Coastguard Worker            reduced_jacobians.append(jacobian_scalars)
696*da0073e9SAndroid Build Coastguard Worker        else:
697*da0073e9SAndroid Build Coastguard Worker            reduced_jacobians.append(filtered_Ju)
698*da0073e9SAndroid Build Coastguard Worker    return reduced_jacobians
699*da0073e9SAndroid Build Coastguard Worker
700*da0073e9SAndroid Build Coastguard Worker
701*da0073e9SAndroid Build Coastguard Workerdef _check_jacobians_equal(j1, j2, atol):
702*da0073e9SAndroid Build Coastguard Worker    # Check whether the max difference between two Jacobian tensors are within some
703*da0073e9SAndroid Build Coastguard Worker    # tolerance `atol`.
704*da0073e9SAndroid Build Coastguard Worker    for j1_x, j2_x in zip(j1, j2):
705*da0073e9SAndroid Build Coastguard Worker        if j1_x.numel() != 0 and (j1_x - j2_x).abs().max() > atol:
706*da0073e9SAndroid Build Coastguard Worker            return False
707*da0073e9SAndroid Build Coastguard Worker    return True
708*da0073e9SAndroid Build Coastguard Worker
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Workerdef _stack_and_check_tensors(
711*da0073e9SAndroid Build Coastguard Worker    list_of_list_of_tensors, inputs, numel_outputs
712*da0073e9SAndroid Build Coastguard Worker) -> Tuple[Tuple[torch.Tensor, ...], bool, bool]:
713*da0073e9SAndroid Build Coastguard Worker    # For the ith tensor in the inner list checks whether it has the same size and
714*da0073e9SAndroid Build Coastguard Worker    # dtype as the ith differentiable input.
715*da0073e9SAndroid Build Coastguard Worker    out_jacobians = _allocate_jacobians_with_inputs(inputs, numel_outputs)
716*da0073e9SAndroid Build Coastguard Worker    diff_input_list = list(_iter_tensors(inputs, True))
717*da0073e9SAndroid Build Coastguard Worker    correct_grad_sizes = True
718*da0073e9SAndroid Build Coastguard Worker    correct_grad_types = True
719*da0073e9SAndroid Build Coastguard Worker    for i, tensor_list in enumerate(list_of_list_of_tensors):
720*da0073e9SAndroid Build Coastguard Worker        inp = diff_input_list[i]
721*da0073e9SAndroid Build Coastguard Worker        out_jacobian = out_jacobians[i]
722*da0073e9SAndroid Build Coastguard Worker        for j, tensor in enumerate(tensor_list):
723*da0073e9SAndroid Build Coastguard Worker            if tensor is not None and tensor.size() != inp.size():
724*da0073e9SAndroid Build Coastguard Worker                correct_grad_sizes = False
725*da0073e9SAndroid Build Coastguard Worker            elif tensor is not None and tensor.dtype != inp.dtype:
726*da0073e9SAndroid Build Coastguard Worker                correct_grad_types = False
727*da0073e9SAndroid Build Coastguard Worker            if tensor is None:
728*da0073e9SAndroid Build Coastguard Worker                out_jacobian[:, j].zero_()
729*da0073e9SAndroid Build Coastguard Worker            else:
730*da0073e9SAndroid Build Coastguard Worker                dense = (
731*da0073e9SAndroid Build Coastguard Worker                    tensor.to_dense() if not tensor.layout == torch.strided else tensor
732*da0073e9SAndroid Build Coastguard Worker                )
733*da0073e9SAndroid Build Coastguard Worker                assert out_jacobian[:, j].numel() == dense.numel()
734*da0073e9SAndroid Build Coastguard Worker                out_jacobian[:, j] = dense.reshape(-1)
735*da0073e9SAndroid Build Coastguard Worker    return out_jacobians, correct_grad_sizes, correct_grad_types
736*da0073e9SAndroid Build Coastguard Worker
737*da0073e9SAndroid Build Coastguard Worker
738*da0073e9SAndroid Build Coastguard WorkerFAILED_NONDET_MSG = """\n
739*da0073e9SAndroid Build Coastguard WorkerNOTE: If your op relies on non-deterministic operations i.e., it is listed here:
740*da0073e9SAndroid Build Coastguard Workerhttps://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
741*da0073e9SAndroid Build Coastguard Workerthis failure might be expected.
742*da0073e9SAndroid Build Coastguard Worker
743*da0073e9SAndroid Build Coastguard WorkerIf you are adding a new operator, please file an issue and then use one of the
744*da0073e9SAndroid Build Coastguard Workerworkarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
745*da0073e9SAndroid Build Coastguard WorkerIf the test
746*da0073e9SAndroid Build Coastguard Worker- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
747*da0073e9SAndroid Build Coastguard Worker  with `nondet_tol=<tol>` as a keyword argument.
748*da0073e9SAndroid Build Coastguard Worker- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
749*da0073e9SAndroid Build Coastguard Worker  to have `gradcheck_nondet_tol=<tol>`.
750*da0073e9SAndroid Build Coastguard Worker- is a Module test (e.g., in common_nn.py), then modify the corresponding
751*da0073e9SAndroid Build Coastguard Worker  module_test entry to have `gradcheck_nondet_tol=<tol>`
752*da0073e9SAndroid Build Coastguard Worker"""
753*da0073e9SAndroid Build Coastguard Worker
754*da0073e9SAndroid Build Coastguard Worker
755*da0073e9SAndroid Build Coastguard Workerdef _check_analytical_jacobian_attributes(
756*da0073e9SAndroid Build Coastguard Worker    inputs, output, nondet_tol, check_grad_dtypes, fast_mode=False, v=None
757*da0073e9SAndroid Build Coastguard Worker) -> Tuple[torch.Tensor, ...]:
758*da0073e9SAndroid Build Coastguard Worker    # This is used by both fast and slow mode:
759*da0073e9SAndroid Build Coastguard Worker    #  - For slow mode, vjps[i][j] is the jth row of the Jacobian wrt the ith
760*da0073e9SAndroid Build Coastguard Worker    #    input.
761*da0073e9SAndroid Build Coastguard Worker    #  - For fast mode, vjps[i][0] is a linear combination of the rows
762*da0073e9SAndroid Build Coastguard Worker    #    of the Jacobian wrt the ith input
763*da0073e9SAndroid Build Coastguard Worker    diff_input_list = list(_iter_tensors(inputs, True))
764*da0073e9SAndroid Build Coastguard Worker
765*da0073e9SAndroid Build Coastguard Worker    def vjp_fn(grad_output):
766*da0073e9SAndroid Build Coastguard Worker        return torch.autograd.grad(
767*da0073e9SAndroid Build Coastguard Worker            output, diff_input_list, grad_output, retain_graph=True, allow_unused=True
768*da0073e9SAndroid Build Coastguard Worker        )
769*da0073e9SAndroid Build Coastguard Worker
770*da0073e9SAndroid Build Coastguard Worker    # Compute everything twice to check for nondeterminism (which we call reentrancy)
771*da0073e9SAndroid Build Coastguard Worker    if fast_mode:
772*da0073e9SAndroid Build Coastguard Worker        vjps1 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v)
773*da0073e9SAndroid Build Coastguard Worker        vjps2 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v)
774*da0073e9SAndroid Build Coastguard Worker    else:
775*da0073e9SAndroid Build Coastguard Worker        vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
776*da0073e9SAndroid Build Coastguard Worker        vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
777*da0073e9SAndroid Build Coastguard Worker
778*da0073e9SAndroid Build Coastguard Worker    output_numel = output.numel() if not fast_mode else 1
779*da0073e9SAndroid Build Coastguard Worker    jacobians1, types_ok, sizes_ok = _stack_and_check_tensors(
780*da0073e9SAndroid Build Coastguard Worker        vjps1, inputs, output_numel
781*da0073e9SAndroid Build Coastguard Worker    )
782*da0073e9SAndroid Build Coastguard Worker    jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel)
783*da0073e9SAndroid Build Coastguard Worker    reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol)
784*da0073e9SAndroid Build Coastguard Worker
785*da0073e9SAndroid Build Coastguard Worker    if not types_ok and check_grad_dtypes:
786*da0073e9SAndroid Build Coastguard Worker        raise GradcheckError("Gradient has dtype mismatch")
787*da0073e9SAndroid Build Coastguard Worker    if not sizes_ok:
788*da0073e9SAndroid Build Coastguard Worker        raise GradcheckError("Analytical gradient has incorrect size")
789*da0073e9SAndroid Build Coastguard Worker    if not reentrant:
790*da0073e9SAndroid Build Coastguard Worker        raise GradcheckError(
791*da0073e9SAndroid Build Coastguard Worker            "Backward is not reentrant, i.e., running backward with "
792*da0073e9SAndroid Build Coastguard Worker            "same input and grad_output multiple times gives different values, "
793*da0073e9SAndroid Build Coastguard Worker            "although analytical gradient matches numerical gradient."
794*da0073e9SAndroid Build Coastguard Worker            f"The tolerance for nondeterminism was {nondet_tol}." + FAILED_NONDET_MSG
795*da0073e9SAndroid Build Coastguard Worker        )
796*da0073e9SAndroid Build Coastguard Worker    return jacobians1
797*da0073e9SAndroid Build Coastguard Worker
798*da0073e9SAndroid Build Coastguard Worker
799*da0073e9SAndroid Build Coastguard Workerdef _get_analytical_vJu_backward_mode(
800*da0073e9SAndroid Build Coastguard Worker    inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u
801*da0073e9SAndroid Build Coastguard Worker):
802*da0073e9SAndroid Build Coastguard Worker    reduced_jacobians: List[List[torch.Tensor]] = []
803*da0073e9SAndroid Build Coastguard Worker    for output, v in zip(outputs, all_v):
804*da0073e9SAndroid Build Coastguard Worker        all_vJ = _check_analytical_jacobian_attributes(
805*da0073e9SAndroid Build Coastguard Worker            inputs, output, nondet_tol, check_grad_dtypes, fast_mode=True, v=v
806*da0073e9SAndroid Build Coastguard Worker        )
807*da0073e9SAndroid Build Coastguard Worker        jacobian_scalars: List[torch.Tensor] = []
808*da0073e9SAndroid Build Coastguard Worker        for vJ, u in zip(all_vJ, all_u):
809*da0073e9SAndroid Build Coastguard Worker            # Why do we need squeeze here? vJ is a 2-d tensor so that we can reuse
810*da0073e9SAndroid Build Coastguard Worker            # the error checking logic from slow mode
811*da0073e9SAndroid Build Coastguard Worker            vJ = vJ.T.squeeze(0)
812*da0073e9SAndroid Build Coastguard Worker            if vJ.is_complex():  # C -> R
813*da0073e9SAndroid Build Coastguard Worker                tv = torch.view_as_real(vJ.resolve_conj())
814*da0073e9SAndroid Build Coastguard Worker                tr = tv.select(-1, 0)
815*da0073e9SAndroid Build Coastguard Worker                ti = tv.select(-1, 1)
816*da0073e9SAndroid Build Coastguard Worker                jacobian_scalars.append(tr.dot(u[0]) + 1j * ti.dot(u[1]))
817*da0073e9SAndroid Build Coastguard Worker            else:  # R -> R
818*da0073e9SAndroid Build Coastguard Worker                jacobian_scalars.append(vJ.dot(u))
819*da0073e9SAndroid Build Coastguard Worker        reduced_jacobians.append(jacobian_scalars)
820*da0073e9SAndroid Build Coastguard Worker    return reduced_jacobians
821*da0073e9SAndroid Build Coastguard Worker
822*da0073e9SAndroid Build Coastguard Worker
823*da0073e9SAndroid Build Coastguard Worker@deprecated(
824*da0073e9SAndroid Build Coastguard Worker    "`get_analytical_jacobian` was part of PyTorch's private API and not "
825*da0073e9SAndroid Build Coastguard Worker    "meant to be exposed. We are deprecating it and it will be removed "
826*da0073e9SAndroid Build Coastguard Worker    "in a future version of PyTorch. If you have a specific use for "
827*da0073e9SAndroid Build Coastguard Worker    "this or feature request for this to be a stable API, please file "
828*da0073e9SAndroid Build Coastguard Worker    "us an issue at https://github.com/pytorch/pytorch/issues/new",
829*da0073e9SAndroid Build Coastguard Worker    category=FutureWarning,
830*da0073e9SAndroid Build Coastguard Worker)
831*da0073e9SAndroid Build Coastguard Workerdef get_analytical_jacobian(inputs, output, nondet_tol=0.0, grad_out=1.0):
832*da0073e9SAndroid Build Coastguard Worker    # Replicates the behavior of the old get_analytical_jacobian before the refactor
833*da0073e9SAndroid Build Coastguard Worker    # This shares much of its code with _check_analytical_jacobian_attributes
834*da0073e9SAndroid Build Coastguard Worker    if (
835*da0073e9SAndroid Build Coastguard Worker        grad_out != 1.0
836*da0073e9SAndroid Build Coastguard Worker    ):  # grad_out param is only kept for backward compatibility reasons
837*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
838*da0073e9SAndroid Build Coastguard Worker            "Expected grad_out to be 1.0. get_analytical_jacobian no longer "
839*da0073e9SAndroid Build Coastguard Worker            "supports values of grad_out != 1.0."
840*da0073e9SAndroid Build Coastguard Worker        )
841*da0073e9SAndroid Build Coastguard Worker    if output.is_complex():
842*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
843*da0073e9SAndroid Build Coastguard Worker            "Expected output to be non-complex. get_analytical_jacobian no "
844*da0073e9SAndroid Build Coastguard Worker            "longer supports functions that return complex outputs."
845*da0073e9SAndroid Build Coastguard Worker        )
846*da0073e9SAndroid Build Coastguard Worker    diff_input_list = list(_iter_tensors(inputs, True))
847*da0073e9SAndroid Build Coastguard Worker
848*da0073e9SAndroid Build Coastguard Worker    def vjp_fn(grad_output):
849*da0073e9SAndroid Build Coastguard Worker        return torch.autograd.grad(
850*da0073e9SAndroid Build Coastguard Worker            output, diff_input_list, grad_output, retain_graph=True, allow_unused=True
851*da0073e9SAndroid Build Coastguard Worker        )
852*da0073e9SAndroid Build Coastguard Worker
853*da0073e9SAndroid Build Coastguard Worker    # Compute everything twice to check for nondeterminism (which we call reentrancy)
854*da0073e9SAndroid Build Coastguard Worker    vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
855*da0073e9SAndroid Build Coastguard Worker    vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
856*da0073e9SAndroid Build Coastguard Worker
857*da0073e9SAndroid Build Coastguard Worker    output_numel = output.numel()
858*da0073e9SAndroid Build Coastguard Worker    jacobians1, types_ok, sizes_ok = _stack_and_check_tensors(
859*da0073e9SAndroid Build Coastguard Worker        vjps1, inputs, output_numel
860*da0073e9SAndroid Build Coastguard Worker    )
861*da0073e9SAndroid Build Coastguard Worker    jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel)
862*da0073e9SAndroid Build Coastguard Worker    reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol)
863*da0073e9SAndroid Build Coastguard Worker
864*da0073e9SAndroid Build Coastguard Worker    return jacobians1, reentrant, sizes_ok, types_ok
865*da0073e9SAndroid Build Coastguard Worker
866*da0073e9SAndroid Build Coastguard Worker
867*da0073e9SAndroid Build Coastguard Workerdef _get_analytical_jacobian(inputs, outputs, input_idx, output_idx):
868*da0073e9SAndroid Build Coastguard Worker    # Computes the analytical Jacobian in slow mode for a single input-output pair.
869*da0073e9SAndroid Build Coastguard Worker    # Forgoes performing checks on dtype, shape, and reentrancy.
870*da0073e9SAndroid Build Coastguard Worker    jacobians = _check_analytical_jacobian_attributes(
871*da0073e9SAndroid Build Coastguard Worker        inputs, outputs[output_idx], nondet_tol=float("inf"), check_grad_dtypes=False
872*da0073e9SAndroid Build Coastguard Worker    )
873*da0073e9SAndroid Build Coastguard Worker    return jacobians[input_idx]
874*da0073e9SAndroid Build Coastguard Worker
875*da0073e9SAndroid Build Coastguard Worker
876*da0073e9SAndroid Build Coastguard Workerdef _compute_analytical_jacobian_rows(
877*da0073e9SAndroid Build Coastguard Worker    vjp_fn, sample_output
878*da0073e9SAndroid Build Coastguard Worker) -> List[List[Optional[torch.Tensor]]]:
879*da0073e9SAndroid Build Coastguard Worker    # Computes Jacobian row-by-row by projecting `vjp_fn` = v^T J on standard basis
880*da0073e9SAndroid Build Coastguard Worker    # vectors: vjp_fn(e) = e^T J is a corresponding row of the Jacobian.
881*da0073e9SAndroid Build Coastguard Worker    # NB: this function does not assume vjp_fn(v) to return tensors with the same
882*da0073e9SAndroid Build Coastguard Worker    # number of elements for different v. This is checked when we later combine the
883*da0073e9SAndroid Build Coastguard Worker    # rows into a single tensor.
884*da0073e9SAndroid Build Coastguard Worker    grad_out_base = torch.zeros_like(
885*da0073e9SAndroid Build Coastguard Worker        sample_output, memory_format=torch.legacy_contiguous_format
886*da0073e9SAndroid Build Coastguard Worker    )
887*da0073e9SAndroid Build Coastguard Worker    flat_grad_out = grad_out_base.view(-1)
888*da0073e9SAndroid Build Coastguard Worker    # jacobians_rows[i][j] is the Jacobian jth row for the ith input
889*da0073e9SAndroid Build Coastguard Worker    jacobians_rows: List[List[Optional[torch.Tensor]]] = []
890*da0073e9SAndroid Build Coastguard Worker    for j in range(flat_grad_out.numel()):
891*da0073e9SAndroid Build Coastguard Worker        flat_grad_out.zero_()
892*da0073e9SAndroid Build Coastguard Worker        flat_grad_out[j] = 1.0  # projection for jth row of Jacobian
893*da0073e9SAndroid Build Coastguard Worker        grad_inputs = vjp_fn(grad_out_base)
894*da0073e9SAndroid Build Coastguard Worker        for i, d_x in enumerate(grad_inputs):
895*da0073e9SAndroid Build Coastguard Worker            if j == 0:
896*da0073e9SAndroid Build Coastguard Worker                jacobians_rows.append([])
897*da0073e9SAndroid Build Coastguard Worker            jacobians_rows[i] += [
898*da0073e9SAndroid Build Coastguard Worker                d_x.clone() if isinstance(d_x, torch.Tensor) else None
899*da0073e9SAndroid Build Coastguard Worker            ]
900*da0073e9SAndroid Build Coastguard Worker    return jacobians_rows
901*da0073e9SAndroid Build Coastguard Worker
902*da0073e9SAndroid Build Coastguard Worker
903*da0073e9SAndroid Build Coastguard Workerdef _get_analytical_vjps_wrt_specific_output(
904*da0073e9SAndroid Build Coastguard Worker    vjp_fn, sample_output, v
905*da0073e9SAndroid Build Coastguard Worker) -> List[List[Optional[torch.Tensor]]]:
906*da0073e9SAndroid Build Coastguard Worker    vjps: List[List[Optional[torch.Tensor]]] = []
907*da0073e9SAndroid Build Coastguard Worker    grad_inputs = vjp_fn(v.reshape(sample_output.shape))
908*da0073e9SAndroid Build Coastguard Worker    for vjp in grad_inputs:
909*da0073e9SAndroid Build Coastguard Worker        vjps.append([vjp.clone() if isinstance(vjp, torch.Tensor) else None])
910*da0073e9SAndroid Build Coastguard Worker    return vjps
911*da0073e9SAndroid Build Coastguard Worker
912*da0073e9SAndroid Build Coastguard Worker
913*da0073e9SAndroid Build Coastguard Workerdef _check_inputs(tupled_inputs) -> bool:
914*da0073e9SAndroid Build Coastguard Worker    # Make sure that gradients are saved for at least one input
915*da0073e9SAndroid Build Coastguard Worker    any_input_requiring_grad = False
916*da0073e9SAndroid Build Coastguard Worker    for idx, inp in enumerate(tupled_inputs):
917*da0073e9SAndroid Build Coastguard Worker        if is_tensor_like(inp) and inp.requires_grad:
918*da0073e9SAndroid Build Coastguard Worker            if not (inp.dtype == torch.float64 or inp.dtype == torch.complex128):
919*da0073e9SAndroid Build Coastguard Worker                warnings.warn(
920*da0073e9SAndroid Build Coastguard Worker                    f"Input #{idx} requires gradient and "
921*da0073e9SAndroid Build Coastguard Worker                    "is not a double precision floating point or complex. "
922*da0073e9SAndroid Build Coastguard Worker                    "This check will likely fail if all the inputs are "
923*da0073e9SAndroid Build Coastguard Worker                    "not of double precision floating point or complex. "
924*da0073e9SAndroid Build Coastguard Worker                )
925*da0073e9SAndroid Build Coastguard Worker            if inp.is_sparse:
926*da0073e9SAndroid Build Coastguard Worker                content = inp._values()
927*da0073e9SAndroid Build Coastguard Worker            elif _is_sparse_compressed_tensor(inp):
928*da0073e9SAndroid Build Coastguard Worker                content = inp.values()
929*da0073e9SAndroid Build Coastguard Worker            else:
930*da0073e9SAndroid Build Coastguard Worker                content = inp
931*da0073e9SAndroid Build Coastguard Worker            # TODO: To cover more problematic cases, replace stride = 0 check with
932*da0073e9SAndroid Build Coastguard Worker            # "any overlap in memory" once we have a proper function to check it.
933*da0073e9SAndroid Build Coastguard Worker            if content.layout is not torch._mkldnn:  # type: ignore[attr-defined]
934*da0073e9SAndroid Build Coastguard Worker                if not all(
935*da0073e9SAndroid Build Coastguard Worker                    st > 0 or sz <= 1
936*da0073e9SAndroid Build Coastguard Worker                    for st, sz in zip(content.stride(), content.size())
937*da0073e9SAndroid Build Coastguard Worker                ):
938*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError(
939*da0073e9SAndroid Build Coastguard Worker                        f"The {idx}th input has a dimension with stride 0. gradcheck only "
940*da0073e9SAndroid Build Coastguard Worker                        "supports inputs that are non-overlapping to be able to "
941*da0073e9SAndroid Build Coastguard Worker                        "compute the numerical gradients correctly. You should call "
942*da0073e9SAndroid Build Coastguard Worker                        ".contiguous on the input before passing it to gradcheck."
943*da0073e9SAndroid Build Coastguard Worker                    )
944*da0073e9SAndroid Build Coastguard Worker            any_input_requiring_grad = True
945*da0073e9SAndroid Build Coastguard Worker
946*da0073e9SAndroid Build Coastguard Worker    if not any_input_requiring_grad:
947*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
948*da0073e9SAndroid Build Coastguard Worker            "gradcheck expects at least one input tensor to require gradient, "
949*da0073e9SAndroid Build Coastguard Worker            "but none of the them have requires_grad=True."
950*da0073e9SAndroid Build Coastguard Worker        )
951*da0073e9SAndroid Build Coastguard Worker    return True
952*da0073e9SAndroid Build Coastguard Worker
953*da0073e9SAndroid Build Coastguard Worker
954*da0073e9SAndroid Build Coastguard Workerdef _check_outputs(outputs) -> None:
955*da0073e9SAndroid Build Coastguard Worker    if any(_is_sparse_any_tensor(t) for t in outputs if isinstance(t, torch.Tensor)):
956*da0073e9SAndroid Build Coastguard Worker        # it is easier to call to_dense() on the sparse output than
957*da0073e9SAndroid Build Coastguard Worker        # to modify analytical jacobian
958*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
959*da0073e9SAndroid Build Coastguard Worker            "Sparse output is not supported at gradcheck yet. "
960*da0073e9SAndroid Build Coastguard Worker            "Please call to_dense(masked_grad=...) on the output of fn for gradcheck."
961*da0073e9SAndroid Build Coastguard Worker        )
962*da0073e9SAndroid Build Coastguard Worker    if any(t.layout == torch._mkldnn for t in outputs if isinstance(t, torch.Tensor)):  # type: ignore[attr-defined]
963*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
964*da0073e9SAndroid Build Coastguard Worker            "MKLDNN output is not supported at gradcheck yet. "
965*da0073e9SAndroid Build Coastguard Worker            "Please call to_dense(masked_grad=...) on the output of fn for gradcheck."
966*da0073e9SAndroid Build Coastguard Worker        )
967*da0073e9SAndroid Build Coastguard Worker
968*da0073e9SAndroid Build Coastguard Worker
969*da0073e9SAndroid Build Coastguard Workerdef _check_no_differentiable_outputs(
970*da0073e9SAndroid Build Coastguard Worker    func, inputs, func_out, eps, *, is_forward_ad
971*da0073e9SAndroid Build Coastguard Worker) -> bool:
972*da0073e9SAndroid Build Coastguard Worker    # When there are no differentiable outputs, numerical gradient for a function is
973*da0073e9SAndroid Build Coastguard Worker    # expected to be zero.
974*da0073e9SAndroid Build Coastguard Worker    jacobians_all_inputs_outputs = _get_numerical_jacobian(
975*da0073e9SAndroid Build Coastguard Worker        func, inputs, func_out, eps=eps, is_forward_ad=is_forward_ad
976*da0073e9SAndroid Build Coastguard Worker    )
977*da0073e9SAndroid Build Coastguard Worker    for jacobians_all_outputs_and_fixed_input in jacobians_all_inputs_outputs:
978*da0073e9SAndroid Build Coastguard Worker        for jacobian in jacobians_all_outputs_and_fixed_input:
979*da0073e9SAndroid Build Coastguard Worker            if torch.ne(jacobian, 0).sum() > 0:
980*da0073e9SAndroid Build Coastguard Worker                raise GradcheckError(
981*da0073e9SAndroid Build Coastguard Worker                    "Numerical gradient for function expected to be zero"
982*da0073e9SAndroid Build Coastguard Worker                )
983*da0073e9SAndroid Build Coastguard Worker    return True
984*da0073e9SAndroid Build Coastguard Worker
985*da0073e9SAndroid Build Coastguard Worker
986*da0073e9SAndroid Build Coastguard Workerdef _check_no_differentiable_outputs_fast(
987*da0073e9SAndroid Build Coastguard Worker    func, func_out, all_inputs, inputs_indices, all_u, eps, nondet_tol
988*da0073e9SAndroid Build Coastguard Worker):
989*da0073e9SAndroid Build Coastguard Worker    for inp_idx, u in zip(inputs_indices, all_u):
990*da0073e9SAndroid Build Coastguard Worker        jvps = _get_numerical_jvp_wrt_specific_input(func, inp_idx, all_inputs, u, eps)
991*da0073e9SAndroid Build Coastguard Worker        for jvp in jvps:
992*da0073e9SAndroid Build Coastguard Worker            if jvp.numel() == 0:
993*da0073e9SAndroid Build Coastguard Worker                continue
994*da0073e9SAndroid Build Coastguard Worker            if (jvp - torch.zeros_like(jvp)).abs().max() > nondet_tol:
995*da0073e9SAndroid Build Coastguard Worker                raise GradcheckError(
996*da0073e9SAndroid Build Coastguard Worker                    "Numerical gradient for function expected to be zero"
997*da0073e9SAndroid Build Coastguard Worker                )
998*da0073e9SAndroid Build Coastguard Worker    return True
999*da0073e9SAndroid Build Coastguard Worker
1000*da0073e9SAndroid Build Coastguard Worker
1001*da0073e9SAndroid Build Coastguard WorkerFAILED_BATCHED_GRAD_MSG = """
1002*da0073e9SAndroid Build Coastguard Workergradcheck or gradgradcheck failed while testing batched gradient computation.
1003*da0073e9SAndroid Build Coastguard WorkerThis could have been invoked in a number of ways (via a test that calls
1004*da0073e9SAndroid Build Coastguard Workergradcheck/gradgradcheck directly or via an autogenerated test).
1005*da0073e9SAndroid Build Coastguard Worker
1006*da0073e9SAndroid Build Coastguard WorkerIf you are adding a new operator, please file an issue and then use one of the
1007*da0073e9SAndroid Build Coastguard Workerworkarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
1008*da0073e9SAndroid Build Coastguard WorkerIf the test
1009*da0073e9SAndroid Build Coastguard Worker- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
1010*da0073e9SAndroid Build Coastguard Worker  with `check_batched_grad=False` as a keyword argument.
1011*da0073e9SAndroid Build Coastguard Worker- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
1012*da0073e9SAndroid Build Coastguard Worker  to have `check_batched_grad=False` and/or `check_batched_gradgrad=False`.
1013*da0073e9SAndroid Build Coastguard Worker
1014*da0073e9SAndroid Build Coastguard WorkerIf you're modifying an existing operator that supports batched grad computation,
1015*da0073e9SAndroid Build Coastguard Workeror wish to make a new operator work with batched grad computation, please read
1016*da0073e9SAndroid Build Coastguard Workerthe following.
1017*da0073e9SAndroid Build Coastguard Worker
1018*da0073e9SAndroid Build Coastguard WorkerTo compute batched grads (e.g., jacobians, hessians), we vmap over the backward
1019*da0073e9SAndroid Build Coastguard Workercomputation. The most common failure case is if there is a 'vmap-incompatible
1020*da0073e9SAndroid Build Coastguard Workeroperation' in the backward pass. Please see
1021*da0073e9SAndroid Build Coastguard WorkerNOTE: [How to write vmap-compatible backward formulas]
1022*da0073e9SAndroid Build Coastguard Workerin the codebase for an explanation of how to fix this.
1023*da0073e9SAndroid Build Coastguard Worker""".strip()
1024*da0073e9SAndroid Build Coastguard Worker
1025*da0073e9SAndroid Build Coastguard WorkerFAILED_BATCHED_GRAD_MSG_FWD_AD = """
1026*da0073e9SAndroid Build Coastguard Workergradcheck failed while testing batched gradient computation with forward-mode AD.
1027*da0073e9SAndroid Build Coastguard WorkerThis test is enabled automatically when both `check_batched_grad=True`
1028*da0073e9SAndroid Build Coastguard Workerand `check_forward_ad=True`, but can be disabled in the following ways
1029*da0073e9SAndroid Build Coastguard Workerdependong on how the test was invoked (via a test that calls gradcheck
1030*da0073e9SAndroid Build Coastguard Workerdirectly or via an autogenerated test).
1031*da0073e9SAndroid Build Coastguard Worker
1032*da0073e9SAndroid Build Coastguard WorkerIf you are adding a new operator, please file an issue and then use one of the
1033*da0073e9SAndroid Build Coastguard Workerworkarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
1034*da0073e9SAndroid Build Coastguard WorkerIf the test
1035*da0073e9SAndroid Build Coastguard Worker- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
1036*da0073e9SAndroid Build Coastguard Worker  with `check_batched_forward_grad=False` as a keyword argument.
1037*da0073e9SAndroid Build Coastguard Worker- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
1038*da0073e9SAndroid Build Coastguard Worker  to have `check_batched_forward_grad=False`
1039*da0073e9SAndroid Build Coastguard Worker"""
1040*da0073e9SAndroid Build Coastguard Worker
1041*da0073e9SAndroid Build Coastguard Worker
1042*da0073e9SAndroid Build Coastguard Workerdef _get_failed_batched_grad_test_msg(
1043*da0073e9SAndroid Build Coastguard Worker    output_idx, input_idx, res, exp, is_forward_ad=False
1044*da0073e9SAndroid Build Coastguard Worker):
1045*da0073e9SAndroid Build Coastguard Worker    return f"""
1046*da0073e9SAndroid Build Coastguard WorkerFor output {output_idx} and input {input_idx}:
1047*da0073e9SAndroid Build Coastguard Worker
1048*da0073e9SAndroid Build Coastguard Worker{FAILED_BATCHED_GRAD_MSG_FWD_AD if is_forward_ad else FAILED_BATCHED_GRAD_MSG}
1049*da0073e9SAndroid Build Coastguard Worker
1050*da0073e9SAndroid Build Coastguard WorkerGot:
1051*da0073e9SAndroid Build Coastguard Worker{res}
1052*da0073e9SAndroid Build Coastguard Worker
1053*da0073e9SAndroid Build Coastguard WorkerExpected:
1054*da0073e9SAndroid Build Coastguard Worker{exp}
1055*da0073e9SAndroid Build Coastguard Worker""".strip()
1056*da0073e9SAndroid Build Coastguard Worker
1057*da0073e9SAndroid Build Coastguard Worker
1058*da0073e9SAndroid Build Coastguard Workerdef _test_batched_grad_forward_ad(func, inputs) -> bool:
1059*da0073e9SAndroid Build Coastguard Worker    fwAD = torch.autograd.forward_ad  # To avoid early import issues (do we need this?)
1060*da0073e9SAndroid Build Coastguard Worker    assert isinstance(inputs, tuple)
1061*da0073e9SAndroid Build Coastguard Worker
1062*da0073e9SAndroid Build Coastguard Worker    for input_idx, current_input in enumerate(inputs):
1063*da0073e9SAndroid Build Coastguard Worker        if not (is_tensor_like(current_input) and current_input.requires_grad):
1064*da0073e9SAndroid Build Coastguard Worker            continue
1065*da0073e9SAndroid Build Coastguard Worker
1066*da0073e9SAndroid Build Coastguard Worker        def jvp(tangent: torch.Tensor):
1067*da0073e9SAndroid Build Coastguard Worker            with fwAD.dual_level():
1068*da0073e9SAndroid Build Coastguard Worker                dual = fwAD.make_dual(current_input.detach(), tangent)
1069*da0073e9SAndroid Build Coastguard Worker                inputs_with_dual = tuple(
1070*da0073e9SAndroid Build Coastguard Worker                    dual
1071*da0073e9SAndroid Build Coastguard Worker                    if idx == input_idx
1072*da0073e9SAndroid Build Coastguard Worker                    else (inp.detach() if is_tensor_like(inp) else inp)
1073*da0073e9SAndroid Build Coastguard Worker                    for idx, inp in enumerate(inputs)
1074*da0073e9SAndroid Build Coastguard Worker                )
1075*da0073e9SAndroid Build Coastguard Worker                dual_outputs = _as_tuple(func(*inputs_with_dual))
1076*da0073e9SAndroid Build Coastguard Worker                ret = []
1077*da0073e9SAndroid Build Coastguard Worker                for dual_output in dual_outputs:
1078*da0073e9SAndroid Build Coastguard Worker                    if dual_output is None:
1079*da0073e9SAndroid Build Coastguard Worker                        continue
1080*da0073e9SAndroid Build Coastguard Worker                    primal_out, tangent_out = fwAD.unpack_dual(dual_output)
1081*da0073e9SAndroid Build Coastguard Worker                    if tangent_out is not None:
1082*da0073e9SAndroid Build Coastguard Worker                        ret.append(tangent_out)
1083*da0073e9SAndroid Build Coastguard Worker                    else:
1084*da0073e9SAndroid Build Coastguard Worker                        ret.append(
1085*da0073e9SAndroid Build Coastguard Worker                            torch.zeros(
1086*da0073e9SAndroid Build Coastguard Worker                                [], dtype=primal_out.dtype, device=primal_out.device
1087*da0073e9SAndroid Build Coastguard Worker                            ).expand(primal_out.shape)
1088*da0073e9SAndroid Build Coastguard Worker                        )
1089*da0073e9SAndroid Build Coastguard Worker                return tuple(ret)
1090*da0073e9SAndroid Build Coastguard Worker
1091*da0073e9SAndroid Build Coastguard Worker        if not _is_float_or_complex_tensor(current_input):
1092*da0073e9SAndroid Build Coastguard Worker            continue
1093*da0073e9SAndroid Build Coastguard Worker
1094*da0073e9SAndroid Build Coastguard Worker        tangents = [torch.randn_like(current_input) for _ in range(2)]
1095*da0073e9SAndroid Build Coastguard Worker        expected = [jvp(t) for t in tangents]
1096*da0073e9SAndroid Build Coastguard Worker        expected = [torch.stack(shards) for shards in zip(*expected)]
1097*da0073e9SAndroid Build Coastguard Worker
1098*da0073e9SAndroid Build Coastguard Worker        try:
1099*da0073e9SAndroid Build Coastguard Worker            result = _vmap(jvp)(torch.stack(tangents))
1100*da0073e9SAndroid Build Coastguard Worker        except RuntimeError as ex:
1101*da0073e9SAndroid Build Coastguard Worker            # Rethrow to provide a better error message
1102*da0073e9SAndroid Build Coastguard Worker            raise GradcheckError(
1103*da0073e9SAndroid Build Coastguard Worker                f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}"
1104*da0073e9SAndroid Build Coastguard Worker            ) from ex
1105*da0073e9SAndroid Build Coastguard Worker
1106*da0073e9SAndroid Build Coastguard Worker        for input_idx, (res, exp) in enumerate(zip(result, expected)):
1107*da0073e9SAndroid Build Coastguard Worker            if torch.allclose(res, exp):
1108*da0073e9SAndroid Build Coastguard Worker                continue
1109*da0073e9SAndroid Build Coastguard Worker            raise GradcheckError(
1110*da0073e9SAndroid Build Coastguard Worker                _get_failed_batched_grad_test_msg(
1111*da0073e9SAndroid Build Coastguard Worker                    input_idx, input_idx, res, exp, is_forward_ad=True
1112*da0073e9SAndroid Build Coastguard Worker                )
1113*da0073e9SAndroid Build Coastguard Worker            )
1114*da0073e9SAndroid Build Coastguard Worker    return True
1115*da0073e9SAndroid Build Coastguard Worker
1116*da0073e9SAndroid Build Coastguard Worker
1117*da0073e9SAndroid Build Coastguard Workerdef _test_batched_grad(input, output, output_idx) -> bool:
1118*da0073e9SAndroid Build Coastguard Worker    # NB: _test_batched_grad compares two autograd.grad invocations with a single
1119*da0073e9SAndroid Build Coastguard Worker    # vmap(autograd.grad) invocation. It's not exactly a "gradcheck" in the
1120*da0073e9SAndroid Build Coastguard Worker    # sense that we're not comparing an analytical jacobian with a numeric one,
1121*da0073e9SAndroid Build Coastguard Worker    # but it is morally similar (we could have computed a full analytic jac
1122*da0073e9SAndroid Build Coastguard Worker    # via vmap, but that is potentially slow)
1123*da0073e9SAndroid Build Coastguard Worker    diff_input_list = list(_iter_tensors(input, True))
1124*da0073e9SAndroid Build Coastguard Worker    grad = functools.partial(
1125*da0073e9SAndroid Build Coastguard Worker        torch.autograd.grad,
1126*da0073e9SAndroid Build Coastguard Worker        output,
1127*da0073e9SAndroid Build Coastguard Worker        diff_input_list,
1128*da0073e9SAndroid Build Coastguard Worker        retain_graph=True,
1129*da0073e9SAndroid Build Coastguard Worker        allow_unused=True,
1130*da0073e9SAndroid Build Coastguard Worker    )
1131*da0073e9SAndroid Build Coastguard Worker
1132*da0073e9SAndroid Build Coastguard Worker    def vjp(v):
1133*da0073e9SAndroid Build Coastguard Worker        results = grad(v)
1134*da0073e9SAndroid Build Coastguard Worker        results = tuple(
1135*da0073e9SAndroid Build Coastguard Worker            grad
1136*da0073e9SAndroid Build Coastguard Worker            if grad is not None
1137*da0073e9SAndroid Build Coastguard Worker            else torch.zeros([], dtype=inp.dtype, device=inp.device).expand(inp.shape)
1138*da0073e9SAndroid Build Coastguard Worker            for grad, inp in zip(results, diff_input_list)
1139*da0073e9SAndroid Build Coastguard Worker        )
1140*da0073e9SAndroid Build Coastguard Worker        return results
1141*da0073e9SAndroid Build Coastguard Worker
1142*da0073e9SAndroid Build Coastguard Worker    grad_outputs = [torch.randn_like(output) for _ in range(2)]
1143*da0073e9SAndroid Build Coastguard Worker
1144*da0073e9SAndroid Build Coastguard Worker    expected = [vjp(gO) for gO in grad_outputs]
1145*da0073e9SAndroid Build Coastguard Worker    expected = [torch.stack(shards) for shards in zip(*expected)]
1146*da0073e9SAndroid Build Coastguard Worker
1147*da0073e9SAndroid Build Coastguard Worker    # Squash warnings since these are expected to happen in most cases
1148*da0073e9SAndroid Build Coastguard Worker    # NB: this doesn't work for CUDA tests: https://github.com/pytorch/pytorch/issues/50209
1149*da0073e9SAndroid Build Coastguard Worker    with warnings.catch_warnings():
1150*da0073e9SAndroid Build Coastguard Worker        warnings.filterwarnings("ignore", message="There is a performance drop")
1151*da0073e9SAndroid Build Coastguard Worker        warnings.filterwarnings("ignore", message="Please use torch.vmap")
1152*da0073e9SAndroid Build Coastguard Worker        try:
1153*da0073e9SAndroid Build Coastguard Worker            result = vmap(vjp)(torch.stack(grad_outputs))
1154*da0073e9SAndroid Build Coastguard Worker        except RuntimeError as ex:
1155*da0073e9SAndroid Build Coastguard Worker            # It's OK that we're not raising the error at the correct callsite.
1156*da0073e9SAndroid Build Coastguard Worker            # That's because the callsite is always going to inside the Python
1157*da0073e9SAndroid Build Coastguard Worker            # autograd.grad instead of the C++ traceback of what line in the
1158*da0073e9SAndroid Build Coastguard Worker            # backward formula
1159*da0073e9SAndroid Build Coastguard Worker            raise GradcheckError(
1160*da0073e9SAndroid Build Coastguard Worker                f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}"
1161*da0073e9SAndroid Build Coastguard Worker            ) from ex
1162*da0073e9SAndroid Build Coastguard Worker
1163*da0073e9SAndroid Build Coastguard Worker    for input_idx, (res, exp) in enumerate(zip(result, expected)):
1164*da0073e9SAndroid Build Coastguard Worker        if torch.allclose(res, exp):
1165*da0073e9SAndroid Build Coastguard Worker            continue
1166*da0073e9SAndroid Build Coastguard Worker        raise GradcheckError(
1167*da0073e9SAndroid Build Coastguard Worker            _get_failed_batched_grad_test_msg(output_idx, input_idx, res, exp)
1168*da0073e9SAndroid Build Coastguard Worker        )
1169*da0073e9SAndroid Build Coastguard Worker    return True
1170*da0073e9SAndroid Build Coastguard Worker
1171*da0073e9SAndroid Build Coastguard Worker
1172*da0073e9SAndroid Build Coastguard Workerdef _test_backward_mul_by_grad_output(outputs, inputs, masked) -> bool:
1173*da0073e9SAndroid Build Coastguard Worker    # Tests that backward is multiplied by grad_output
1174*da0073e9SAndroid Build Coastguard Worker    diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
1175*da0073e9SAndroid Build Coastguard Worker    if not diff_input_list:
1176*da0073e9SAndroid Build Coastguard Worker        raise GradcheckError("no Tensors requiring grad found in input")
1177*da0073e9SAndroid Build Coastguard Worker    grads_input = torch.autograd.grad(
1178*da0073e9SAndroid Build Coastguard Worker        outputs,
1179*da0073e9SAndroid Build Coastguard Worker        diff_input_list,
1180*da0073e9SAndroid Build Coastguard Worker        [
1181*da0073e9SAndroid Build Coastguard Worker            torch.zeros_like(o, memory_format=torch.legacy_contiguous_format)
1182*da0073e9SAndroid Build Coastguard Worker            for o in outputs
1183*da0073e9SAndroid Build Coastguard Worker        ],
1184*da0073e9SAndroid Build Coastguard Worker        allow_unused=True,
1185*da0073e9SAndroid Build Coastguard Worker    )
1186*da0073e9SAndroid Build Coastguard Worker    for gi, di in zip(grads_input, diff_input_list):
1187*da0073e9SAndroid Build Coastguard Worker        if gi is None:
1188*da0073e9SAndroid Build Coastguard Worker            continue
1189*da0073e9SAndroid Build Coastguard Worker        if isinstance(gi, torch.Tensor) and gi.layout != torch.strided:
1190*da0073e9SAndroid Build Coastguard Worker            if gi.layout != di.layout:
1191*da0073e9SAndroid Build Coastguard Worker                raise GradcheckError(
1192*da0073e9SAndroid Build Coastguard Worker                    "grad is incorrect layout ("
1193*da0073e9SAndroid Build Coastguard Worker                    + str(gi.layout)
1194*da0073e9SAndroid Build Coastguard Worker                    + " is not "
1195*da0073e9SAndroid Build Coastguard Worker                    + str(di.layout)
1196*da0073e9SAndroid Build Coastguard Worker                    + ")"
1197*da0073e9SAndroid Build Coastguard Worker                )
1198*da0073e9SAndroid Build Coastguard Worker            if _is_sparse_any_tensor(gi):
1199*da0073e9SAndroid Build Coastguard Worker                sparse_kind = str(gi.layout).replace("torch.", "").replace("_coo", "")
1200*da0073e9SAndroid Build Coastguard Worker                if gi.sparse_dim() != di.sparse_dim():
1201*da0073e9SAndroid Build Coastguard Worker                    raise GradcheckError(
1202*da0073e9SAndroid Build Coastguard Worker                        f"grad is {sparse_kind} tensor, but has incorrect sparse_dim"
1203*da0073e9SAndroid Build Coastguard Worker                        f" {gi.sparse_dim()}, expected {di.sparse_dim()}"
1204*da0073e9SAndroid Build Coastguard Worker                    )
1205*da0073e9SAndroid Build Coastguard Worker                if gi.dense_dim() != di.dense_dim():
1206*da0073e9SAndroid Build Coastguard Worker                    raise GradcheckError(
1207*da0073e9SAndroid Build Coastguard Worker                        f"grad is {sparse_kind} tensor, but has incorrect dense_dim"
1208*da0073e9SAndroid Build Coastguard Worker                        f" {gi.dense_dim()}, expected {di.dense_dim()}"
1209*da0073e9SAndroid Build Coastguard Worker                    )
1210*da0073e9SAndroid Build Coastguard Worker            gi = gi.to_dense()
1211*da0073e9SAndroid Build Coastguard Worker            di = di.to_dense()
1212*da0073e9SAndroid Build Coastguard Worker        if masked:
1213*da0073e9SAndroid Build Coastguard Worker            if not torch.allclose(gi, torch.zeros_like(gi)):
1214*da0073e9SAndroid Build Coastguard Worker                raise GradcheckError("backward not multiplied by grad_output")
1215*da0073e9SAndroid Build Coastguard Worker        elif not gi.eq(0).all():
1216*da0073e9SAndroid Build Coastguard Worker            raise GradcheckError("backward not multiplied by grad_output")
1217*da0073e9SAndroid Build Coastguard Worker        if gi.dtype != di.dtype:
1218*da0073e9SAndroid Build Coastguard Worker            raise GradcheckError("grad is incorrect type")
1219*da0073e9SAndroid Build Coastguard Worker        if gi.device != di.device:
1220*da0073e9SAndroid Build Coastguard Worker            raise GradcheckError("grad is incorrect device")
1221*da0073e9SAndroid Build Coastguard Worker        if gi.size() != di.size():
1222*da0073e9SAndroid Build Coastguard Worker            raise GradcheckError("grad is incorrect size")
1223*da0073e9SAndroid Build Coastguard Worker    return True
1224*da0073e9SAndroid Build Coastguard Worker
1225*da0073e9SAndroid Build Coastguard Worker
1226*da0073e9SAndroid Build Coastguard Workerdef _test_undefined_forward_mode(func, outputs, inputs):
1227*da0073e9SAndroid Build Coastguard Worker    fwAD = torch.autograd.forward_ad
1228*da0073e9SAndroid Build Coastguard Worker
1229*da0073e9SAndroid Build Coastguard Worker    inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs)
1230*da0073e9SAndroid Build Coastguard Worker    all_v, all_u, all_u_dense = _make_vectors(inp_tensors, outputs, use_forward_ad=True)
1231*da0073e9SAndroid Build Coastguard Worker
1232*da0073e9SAndroid Build Coastguard Worker    tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad)
1233*da0073e9SAndroid Build Coastguard Worker
1234*da0073e9SAndroid Build Coastguard Worker    with fwAD.dual_level():
1235*da0073e9SAndroid Build Coastguard Worker        fw_grads = []
1236*da0073e9SAndroid Build Coastguard Worker        dual_inputs = []
1237*da0073e9SAndroid Build Coastguard Worker        tensor_indices = set()
1238*da0073e9SAndroid Build Coastguard Worker        for i, inp in enumerate(inputs):
1239*da0073e9SAndroid Build Coastguard Worker            if is_tensor_like(inp) and inp.requires_grad:
1240*da0073e9SAndroid Build Coastguard Worker                if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
1241*da0073e9SAndroid Build Coastguard Worker                    raise ValueError(
1242*da0073e9SAndroid Build Coastguard Worker                        "MKLDNN inputs are not support for forward AD gradcheck."
1243*da0073e9SAndroid Build Coastguard Worker                    )
1244*da0073e9SAndroid Build Coastguard Worker
1245*da0073e9SAndroid Build Coastguard Worker                inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
1246*da0073e9SAndroid Build Coastguard Worker                # If inp is a differentiable view, the dual might not be the tangent given to
1247*da0073e9SAndroid Build Coastguard Worker                # make_dual, so read it explicitly from the dual tensor
1248*da0073e9SAndroid Build Coastguard Worker                fw_grads.append(fwAD.unpack_dual(inp)[1])
1249*da0073e9SAndroid Build Coastguard Worker                tensor_indices.add(i)
1250*da0073e9SAndroid Build Coastguard Worker            dual_inputs.append(inp)
1251*da0073e9SAndroid Build Coastguard Worker
1252*da0073e9SAndroid Build Coastguard Worker        for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
1253*da0073e9SAndroid Build Coastguard Worker            fw_grad.copy_(u.view_as(fw_grad))
1254*da0073e9SAndroid Build Coastguard Worker
1255*da0073e9SAndroid Build Coastguard Worker        for idx, inp in enumerate(inputs):
1256*da0073e9SAndroid Build Coastguard Worker            if idx not in tensor_indices:
1257*da0073e9SAndroid Build Coastguard Worker                continue
1258*da0073e9SAndroid Build Coastguard Worker            dual_inp_obj = dual_inputs[idx]
1259*da0073e9SAndroid Build Coastguard Worker
1260*da0073e9SAndroid Build Coastguard Worker            # case 1 (Materialized Zero Tensor Tangent)
1261*da0073e9SAndroid Build Coastguard Worker            dual_inputs[idx] = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
1262*da0073e9SAndroid Build Coastguard Worker            raw_outputs = _as_tuple(func(*dual_inputs))
1263*da0073e9SAndroid Build Coastguard Worker            dual_outputs1 = filter(_is_float_or_complex_tensor, raw_outputs)
1264*da0073e9SAndroid Build Coastguard Worker
1265*da0073e9SAndroid Build Coastguard Worker            # case 2 (Efficient Zero Tensor Tangent since we don't make a dual object and pass a regular tensor)
1266*da0073e9SAndroid Build Coastguard Worker            dual_inputs[idx] = inp.detach()
1267*da0073e9SAndroid Build Coastguard Worker            raw_outputs = _as_tuple(func(*dual_inputs))
1268*da0073e9SAndroid Build Coastguard Worker            dual_outputs2 = filter(_is_float_or_complex_tensor, raw_outputs)
1269*da0073e9SAndroid Build Coastguard Worker
1270*da0073e9SAndroid Build Coastguard Worker            # reset
1271*da0073e9SAndroid Build Coastguard Worker            dual_inputs[idx] = dual_inp_obj
1272*da0073e9SAndroid Build Coastguard Worker
1273*da0073e9SAndroid Build Coastguard Worker            for index_o, (d_o1, d_o2) in enumerate(zip(dual_outputs1, dual_outputs2)):
1274*da0073e9SAndroid Build Coastguard Worker                val1, res1 = fwAD.unpack_dual(d_o1)
1275*da0073e9SAndroid Build Coastguard Worker                val2, res2 = fwAD.unpack_dual(d_o2)
1276*da0073e9SAndroid Build Coastguard Worker
1277*da0073e9SAndroid Build Coastguard Worker                if not (res1 is None or res2 is None):
1278*da0073e9SAndroid Build Coastguard Worker                    if not torch.allclose(res1, res2):
1279*da0073e9SAndroid Build Coastguard Worker                        raise GradcheckError(
1280*da0073e9SAndroid Build Coastguard Worker                            "Mismatch in tangent values for output with index: ",
1281*da0073e9SAndroid Build Coastguard Worker                            index_o,
1282*da0073e9SAndroid Build Coastguard Worker                            " when input: ",
1283*da0073e9SAndroid Build Coastguard Worker                            inp,
1284*da0073e9SAndroid Build Coastguard Worker                            " has an undefined tangent value. ",
1285*da0073e9SAndroid Build Coastguard Worker                            " Got: ",
1286*da0073e9SAndroid Build Coastguard Worker                            res1,
1287*da0073e9SAndroid Build Coastguard Worker                            " but expected: ",
1288*da0073e9SAndroid Build Coastguard Worker                            res2,
1289*da0073e9SAndroid Build Coastguard Worker                        )
1290*da0073e9SAndroid Build Coastguard Worker    return True
1291*da0073e9SAndroid Build Coastguard Worker
1292*da0073e9SAndroid Build Coastguard Worker
1293*da0073e9SAndroid Build Coastguard Workerdef _test_undefined_backward_mode(func, outputs, inputs) -> bool:
1294*da0073e9SAndroid Build Coastguard Worker    diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
1295*da0073e9SAndroid Build Coastguard Worker    if not diff_input_list:
1296*da0073e9SAndroid Build Coastguard Worker        raise GradcheckError("no Tensors requiring grad found in input")
1297*da0073e9SAndroid Build Coastguard Worker
1298*da0073e9SAndroid Build Coastguard Worker    def warn_bc_breaking():
1299*da0073e9SAndroid Build Coastguard Worker        warnings.warn(
1300*da0073e9SAndroid Build Coastguard Worker            "Backwards compatibility: New undefined gradient support checking "
1301*da0073e9SAndroid Build Coastguard Worker            "feature is enabled by default, but it may break existing callers "
1302*da0073e9SAndroid Build Coastguard Worker            "of this function. If this is true for you, you can call this "
1303*da0073e9SAndroid Build Coastguard Worker            'function with "check_undefined_grad=False" to disable the feature'
1304*da0073e9SAndroid Build Coastguard Worker        )
1305*da0073e9SAndroid Build Coastguard Worker
1306*da0073e9SAndroid Build Coastguard Worker    def check_undefined_grad_support(output_to_check):
1307*da0073e9SAndroid Build Coastguard Worker        grads_output = [
1308*da0073e9SAndroid Build Coastguard Worker            torch.zeros_like(o, memory_format=torch.legacy_contiguous_format)
1309*da0073e9SAndroid Build Coastguard Worker            for o in output_to_check
1310*da0073e9SAndroid Build Coastguard Worker        ]
1311*da0073e9SAndroid Build Coastguard Worker        try:
1312*da0073e9SAndroid Build Coastguard Worker            grads_input = torch.autograd.grad(
1313*da0073e9SAndroid Build Coastguard Worker                output_to_check, diff_input_list, grads_output, allow_unused=True
1314*da0073e9SAndroid Build Coastguard Worker            )
1315*da0073e9SAndroid Build Coastguard Worker        except RuntimeError as e:
1316*da0073e9SAndroid Build Coastguard Worker            warn_bc_breaking()
1317*da0073e9SAndroid Build Coastguard Worker            raise GradcheckError(
1318*da0073e9SAndroid Build Coastguard Worker                "Expected backward function to handle undefined output grads. "
1319*da0073e9SAndroid Build Coastguard Worker                'Please look at "Notes about undefined output gradients" in '
1320*da0073e9SAndroid Build Coastguard Worker                '"tools/autograd/derivatives.yaml"'
1321*da0073e9SAndroid Build Coastguard Worker            ) from e
1322*da0073e9SAndroid Build Coastguard Worker
1323*da0073e9SAndroid Build Coastguard Worker        for gi, i in zip(grads_input, diff_input_list):
1324*da0073e9SAndroid Build Coastguard Worker            if (gi is not None) and (not gi.eq(0).all()):
1325*da0073e9SAndroid Build Coastguard Worker                warn_bc_breaking()
1326*da0073e9SAndroid Build Coastguard Worker                raise GradcheckError(
1327*da0073e9SAndroid Build Coastguard Worker                    "Expected all input grads to be undefined or zero when all output grads are undefined "
1328*da0073e9SAndroid Build Coastguard Worker                    'or zero. Please look at "Notes about undefined output gradients" in '
1329*da0073e9SAndroid Build Coastguard Worker                    '"tools/autograd/derivatives.yaml"'
1330*da0073e9SAndroid Build Coastguard Worker                )
1331*da0073e9SAndroid Build Coastguard Worker        return True
1332*da0073e9SAndroid Build Coastguard Worker
1333*da0073e9SAndroid Build Coastguard Worker    # All backward functions must work properly if all output grads are undefined
1334*da0073e9SAndroid Build Coastguard Worker    outputs_to_check = [
1335*da0073e9SAndroid Build Coastguard Worker        [
1336*da0073e9SAndroid Build Coastguard Worker            torch._C._functions.UndefinedGrad()(o)
1337*da0073e9SAndroid Build Coastguard Worker            for o in _differentiable_outputs(func(*inputs))
1338*da0073e9SAndroid Build Coastguard Worker            # This check filters out Tensor-likes that aren't instances of Tensor.
1339*da0073e9SAndroid Build Coastguard Worker            if isinstance(o, torch.Tensor)
1340*da0073e9SAndroid Build Coastguard Worker        ]
1341*da0073e9SAndroid Build Coastguard Worker    ]
1342*da0073e9SAndroid Build Coastguard Worker
1343*da0073e9SAndroid Build Coastguard Worker    # If there are multiple output grads, we should be able to undef one at a time without error
1344*da0073e9SAndroid Build Coastguard Worker    if len(outputs_to_check[0]) > 1:
1345*da0073e9SAndroid Build Coastguard Worker        for undef_grad_idx in range(len(outputs)):
1346*da0073e9SAndroid Build Coastguard Worker            output_to_check = _differentiable_outputs(func(*inputs))
1347*da0073e9SAndroid Build Coastguard Worker            outputs_to_check.append(
1348*da0073e9SAndroid Build Coastguard Worker                [
1349*da0073e9SAndroid Build Coastguard Worker                    torch._C._functions.UndefinedGrad()(o)
1350*da0073e9SAndroid Build Coastguard Worker                    if idx == undef_grad_idx
1351*da0073e9SAndroid Build Coastguard Worker                    else o
1352*da0073e9SAndroid Build Coastguard Worker                    for idx, o in enumerate(output_to_check)
1353*da0073e9SAndroid Build Coastguard Worker                ]
1354*da0073e9SAndroid Build Coastguard Worker            )
1355*da0073e9SAndroid Build Coastguard Worker
1356*da0073e9SAndroid Build Coastguard Worker    return all(check_undefined_grad_support(output) for output in outputs_to_check)
1357*da0073e9SAndroid Build Coastguard Worker
1358*da0073e9SAndroid Build Coastguard Worker
1359*da0073e9SAndroid Build Coastguard Workerdef _as_tuple(x):
1360*da0073e9SAndroid Build Coastguard Worker    if isinstance(x, tuple):
1361*da0073e9SAndroid Build Coastguard Worker        return x
1362*da0073e9SAndroid Build Coastguard Worker    elif isinstance(x, list):
1363*da0073e9SAndroid Build Coastguard Worker        return tuple(x)
1364*da0073e9SAndroid Build Coastguard Worker    else:
1365*da0073e9SAndroid Build Coastguard Worker        return (x,)
1366*da0073e9SAndroid Build Coastguard Worker
1367*da0073e9SAndroid Build Coastguard Worker
1368*da0073e9SAndroid Build Coastguard Workerdef _differentiable_outputs(x):
1369*da0073e9SAndroid Build Coastguard Worker    return tuple(o for o in _as_tuple(x) if o.requires_grad)
1370*da0073e9SAndroid Build Coastguard Worker
1371*da0073e9SAndroid Build Coastguard Worker
1372*da0073e9SAndroid Build Coastguard Workerdef _get_notallclose_msg(
1373*da0073e9SAndroid Build Coastguard Worker    analytical,
1374*da0073e9SAndroid Build Coastguard Worker    numerical,
1375*da0073e9SAndroid Build Coastguard Worker    output_idx,
1376*da0073e9SAndroid Build Coastguard Worker    input_idx,
1377*da0073e9SAndroid Build Coastguard Worker    complex_indices,
1378*da0073e9SAndroid Build Coastguard Worker    test_imag=False,
1379*da0073e9SAndroid Build Coastguard Worker    is_forward_ad=False,
1380*da0073e9SAndroid Build Coastguard Worker) -> str:
1381*da0073e9SAndroid Build Coastguard Worker    out_is_complex = (
1382*da0073e9SAndroid Build Coastguard Worker        (not is_forward_ad) and complex_indices and output_idx in complex_indices
1383*da0073e9SAndroid Build Coastguard Worker    )
1384*da0073e9SAndroid Build Coastguard Worker    inp_is_complex = is_forward_ad and complex_indices and input_idx in complex_indices
1385*da0073e9SAndroid Build Coastguard Worker    part = "imaginary" if test_imag else "real"
1386*da0073e9SAndroid Build Coastguard Worker    element = "inputs" if is_forward_ad else "outputs"
1387*da0073e9SAndroid Build Coastguard Worker    prefix = (
1388*da0073e9SAndroid Build Coastguard Worker        ""
1389*da0073e9SAndroid Build Coastguard Worker        if not (out_is_complex or inp_is_complex)
1390*da0073e9SAndroid Build Coastguard Worker        else f"While considering the {part} part of complex {element} only, "
1391*da0073e9SAndroid Build Coastguard Worker    )
1392*da0073e9SAndroid Build Coastguard Worker    mode = "computed with forward mode " if is_forward_ad else ""
1393*da0073e9SAndroid Build Coastguard Worker    return (
1394*da0073e9SAndroid Build Coastguard Worker        prefix + "Jacobian %smismatch for output %d with respect to input %d,\n"
1395*da0073e9SAndroid Build Coastguard Worker        "numerical:%s\nanalytical:%s\n"
1396*da0073e9SAndroid Build Coastguard Worker        % (mode, output_idx, input_idx, numerical, analytical)
1397*da0073e9SAndroid Build Coastguard Worker    )
1398*da0073e9SAndroid Build Coastguard Worker
1399*da0073e9SAndroid Build Coastguard Worker
1400*da0073e9SAndroid Build Coastguard Workerdef _transpose(matrix_of_tensors):
1401*da0073e9SAndroid Build Coastguard Worker    # returns list of tuples
1402*da0073e9SAndroid Build Coastguard Worker    return list(zip(*matrix_of_tensors))
1403*da0073e9SAndroid Build Coastguard Worker
1404*da0073e9SAndroid Build Coastguard Worker
1405*da0073e9SAndroid Build Coastguard Workerdef _real_and_imag_output(fn):
1406*da0073e9SAndroid Build Coastguard Worker    # returns new functions real(fn), and imag(fn) where real(fn) and imag(fn) behave the same as
1407*da0073e9SAndroid Build Coastguard Worker    # the original fn, except torch.real or torch.imag are applied to the complex outputs
1408*da0073e9SAndroid Build Coastguard Worker    def apply_to_c_outs(fn, fn_to_apply):
1409*da0073e9SAndroid Build Coastguard Worker        def wrapped_fn(*inputs):
1410*da0073e9SAndroid Build Coastguard Worker            outs = _as_tuple(fn(*inputs))
1411*da0073e9SAndroid Build Coastguard Worker            return tuple(fn_to_apply(o) if o.is_complex() else o for o in outs)
1412*da0073e9SAndroid Build Coastguard Worker
1413*da0073e9SAndroid Build Coastguard Worker        return wrapped_fn
1414*da0073e9SAndroid Build Coastguard Worker
1415*da0073e9SAndroid Build Coastguard Worker    return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch.imag)
1416*da0073e9SAndroid Build Coastguard Worker
1417*da0073e9SAndroid Build Coastguard Worker
1418*da0073e9SAndroid Build Coastguard Workerdef _real_and_imag_input(fn, complex_inp_indices, tupled_inputs):
1419*da0073e9SAndroid Build Coastguard Worker    # returns new functions that take real inputs instead of complex inputs as
1420*da0073e9SAndroid Build Coastguard Worker    # (x, y) -> fn(x + y * 1j). And it computes: inp -> fn(inp + y * 1j) and inp -> fn(x + inp * 1j).
1421*da0073e9SAndroid Build Coastguard Worker    # In each case, the other part is considered constant.
1422*da0073e9SAndroid Build Coastguard Worker    # We do not use 0 for the constant here to make sure we always call the user function with a valid input.
1423*da0073e9SAndroid Build Coastguard Worker    def apply_to_c_inps(fn, fn_to_apply):
1424*da0073e9SAndroid Build Coastguard Worker        def wrapped_fn(*inputs):
1425*da0073e9SAndroid Build Coastguard Worker            new_inputs = list(inputs)
1426*da0073e9SAndroid Build Coastguard Worker            for should_be_complex in complex_inp_indices:
1427*da0073e9SAndroid Build Coastguard Worker                new_inputs[should_be_complex] = fn_to_apply(
1428*da0073e9SAndroid Build Coastguard Worker                    new_inputs[should_be_complex], tupled_inputs[should_be_complex]
1429*da0073e9SAndroid Build Coastguard Worker                )
1430*da0073e9SAndroid Build Coastguard Worker            return _as_tuple(fn(*new_inputs))
1431*da0073e9SAndroid Build Coastguard Worker
1432*da0073e9SAndroid Build Coastguard Worker        return wrapped_fn
1433*da0073e9SAndroid Build Coastguard Worker
1434*da0073e9SAndroid Build Coastguard Worker    real_fn = apply_to_c_inps(fn, lambda inp, orig: inp + orig.imag * 1j)
1435*da0073e9SAndroid Build Coastguard Worker    imag_fn = apply_to_c_inps(fn, lambda inp, orig: orig.real + inp * 1j)
1436*da0073e9SAndroid Build Coastguard Worker    return real_fn, imag_fn
1437*da0073e9SAndroid Build Coastguard Worker
1438*da0073e9SAndroid Build Coastguard Worker
1439*da0073e9SAndroid Build Coastguard Workerdef _gradcheck_real_imag(
1440*da0073e9SAndroid Build Coastguard Worker    gradcheck_fn,
1441*da0073e9SAndroid Build Coastguard Worker    func,
1442*da0073e9SAndroid Build Coastguard Worker    func_out,
1443*da0073e9SAndroid Build Coastguard Worker    tupled_inputs,
1444*da0073e9SAndroid Build Coastguard Worker    outputs,
1445*da0073e9SAndroid Build Coastguard Worker    eps,
1446*da0073e9SAndroid Build Coastguard Worker    rtol,
1447*da0073e9SAndroid Build Coastguard Worker    atol,
1448*da0073e9SAndroid Build Coastguard Worker    check_grad_dtypes,
1449*da0073e9SAndroid Build Coastguard Worker    check_forward_ad,
1450*da0073e9SAndroid Build Coastguard Worker    check_backward_ad,
1451*da0073e9SAndroid Build Coastguard Worker    nondet_tol,
1452*da0073e9SAndroid Build Coastguard Worker    check_undefined_grad,
1453*da0073e9SAndroid Build Coastguard Worker):
1454*da0073e9SAndroid Build Coastguard Worker    complex_out_indices = [i for i, o in enumerate(outputs) if o.is_complex()]
1455*da0073e9SAndroid Build Coastguard Worker    has_any_complex_output = any(o.is_complex() for o in _as_tuple(func_out))
1456*da0073e9SAndroid Build Coastguard Worker    if check_backward_ad:
1457*da0073e9SAndroid Build Coastguard Worker        if has_any_complex_output:
1458*da0073e9SAndroid Build Coastguard Worker            real_fn, imag_fn = _real_and_imag_output(func)
1459*da0073e9SAndroid Build Coastguard Worker
1460*da0073e9SAndroid Build Coastguard Worker            imag_func_out = imag_fn(*tupled_inputs)
1461*da0073e9SAndroid Build Coastguard Worker            imag_outputs = _differentiable_outputs(imag_func_out)
1462*da0073e9SAndroid Build Coastguard Worker            gradcheck_fn(
1463*da0073e9SAndroid Build Coastguard Worker                imag_fn,
1464*da0073e9SAndroid Build Coastguard Worker                imag_func_out,
1465*da0073e9SAndroid Build Coastguard Worker                tupled_inputs,
1466*da0073e9SAndroid Build Coastguard Worker                imag_outputs,
1467*da0073e9SAndroid Build Coastguard Worker                eps,
1468*da0073e9SAndroid Build Coastguard Worker                rtol,
1469*da0073e9SAndroid Build Coastguard Worker                atol,
1470*da0073e9SAndroid Build Coastguard Worker                check_grad_dtypes,
1471*da0073e9SAndroid Build Coastguard Worker                nondet_tol,
1472*da0073e9SAndroid Build Coastguard Worker                complex_indices=complex_out_indices,
1473*da0073e9SAndroid Build Coastguard Worker                test_imag=True,
1474*da0073e9SAndroid Build Coastguard Worker            )
1475*da0073e9SAndroid Build Coastguard Worker
1476*da0073e9SAndroid Build Coastguard Worker            real_func_out = real_fn(*tupled_inputs)
1477*da0073e9SAndroid Build Coastguard Worker            real_outputs = _differentiable_outputs(real_func_out)
1478*da0073e9SAndroid Build Coastguard Worker            gradcheck_fn(
1479*da0073e9SAndroid Build Coastguard Worker                real_fn,
1480*da0073e9SAndroid Build Coastguard Worker                real_func_out,
1481*da0073e9SAndroid Build Coastguard Worker                tupled_inputs,
1482*da0073e9SAndroid Build Coastguard Worker                real_outputs,
1483*da0073e9SAndroid Build Coastguard Worker                eps,
1484*da0073e9SAndroid Build Coastguard Worker                rtol,
1485*da0073e9SAndroid Build Coastguard Worker                atol,
1486*da0073e9SAndroid Build Coastguard Worker                check_grad_dtypes,
1487*da0073e9SAndroid Build Coastguard Worker                nondet_tol,
1488*da0073e9SAndroid Build Coastguard Worker                complex_indices=complex_out_indices,
1489*da0073e9SAndroid Build Coastguard Worker            )
1490*da0073e9SAndroid Build Coastguard Worker        else:
1491*da0073e9SAndroid Build Coastguard Worker            gradcheck_fn(
1492*da0073e9SAndroid Build Coastguard Worker                func,
1493*da0073e9SAndroid Build Coastguard Worker                func_out,
1494*da0073e9SAndroid Build Coastguard Worker                tupled_inputs,
1495*da0073e9SAndroid Build Coastguard Worker                outputs,
1496*da0073e9SAndroid Build Coastguard Worker                eps,
1497*da0073e9SAndroid Build Coastguard Worker                rtol,
1498*da0073e9SAndroid Build Coastguard Worker                atol,
1499*da0073e9SAndroid Build Coastguard Worker                check_grad_dtypes,
1500*da0073e9SAndroid Build Coastguard Worker                nondet_tol,
1501*da0073e9SAndroid Build Coastguard Worker            )
1502*da0073e9SAndroid Build Coastguard Worker
1503*da0073e9SAndroid Build Coastguard Worker    if check_forward_ad:
1504*da0073e9SAndroid Build Coastguard Worker        complex_inp_indices = [
1505*da0073e9SAndroid Build Coastguard Worker            i
1506*da0073e9SAndroid Build Coastguard Worker            for i, inp in enumerate(tupled_inputs)
1507*da0073e9SAndroid Build Coastguard Worker            if is_tensor_like(inp) and inp.is_complex()
1508*da0073e9SAndroid Build Coastguard Worker        ]
1509*da0073e9SAndroid Build Coastguard Worker        if complex_inp_indices:
1510*da0073e9SAndroid Build Coastguard Worker            real_fn, imag_fn = _real_and_imag_input(
1511*da0073e9SAndroid Build Coastguard Worker                func, complex_inp_indices, tupled_inputs
1512*da0073e9SAndroid Build Coastguard Worker            )
1513*da0073e9SAndroid Build Coastguard Worker
1514*da0073e9SAndroid Build Coastguard Worker            imag_inputs = [
1515*da0073e9SAndroid Build Coastguard Worker                inp.imag if is_tensor_like(inp) and inp.is_complex() else inp
1516*da0073e9SAndroid Build Coastguard Worker                for inp in tupled_inputs
1517*da0073e9SAndroid Build Coastguard Worker            ]
1518*da0073e9SAndroid Build Coastguard Worker            imag_func_out = imag_fn(*imag_inputs)
1519*da0073e9SAndroid Build Coastguard Worker            diff_imag_func_out = _differentiable_outputs(imag_func_out)
1520*da0073e9SAndroid Build Coastguard Worker            gradcheck_fn(
1521*da0073e9SAndroid Build Coastguard Worker                imag_fn,
1522*da0073e9SAndroid Build Coastguard Worker                imag_func_out,
1523*da0073e9SAndroid Build Coastguard Worker                imag_inputs,
1524*da0073e9SAndroid Build Coastguard Worker                diff_imag_func_out,
1525*da0073e9SAndroid Build Coastguard Worker                eps,
1526*da0073e9SAndroid Build Coastguard Worker                rtol,
1527*da0073e9SAndroid Build Coastguard Worker                atol,
1528*da0073e9SAndroid Build Coastguard Worker                check_grad_dtypes,
1529*da0073e9SAndroid Build Coastguard Worker                nondet_tol,
1530*da0073e9SAndroid Build Coastguard Worker                complex_indices=complex_inp_indices,
1531*da0073e9SAndroid Build Coastguard Worker                test_imag=True,
1532*da0073e9SAndroid Build Coastguard Worker                use_forward_ad=True,
1533*da0073e9SAndroid Build Coastguard Worker            )
1534*da0073e9SAndroid Build Coastguard Worker
1535*da0073e9SAndroid Build Coastguard Worker            real_inputs = [
1536*da0073e9SAndroid Build Coastguard Worker                inp.real if is_tensor_like(inp) and inp.is_complex() else inp
1537*da0073e9SAndroid Build Coastguard Worker                for inp in tupled_inputs
1538*da0073e9SAndroid Build Coastguard Worker            ]
1539*da0073e9SAndroid Build Coastguard Worker            real_func_out = real_fn(*real_inputs)
1540*da0073e9SAndroid Build Coastguard Worker            diff_real_func_out = _differentiable_outputs(real_func_out)
1541*da0073e9SAndroid Build Coastguard Worker            gradcheck_fn(
1542*da0073e9SAndroid Build Coastguard Worker                real_fn,
1543*da0073e9SAndroid Build Coastguard Worker                real_func_out,
1544*da0073e9SAndroid Build Coastguard Worker                real_inputs,
1545*da0073e9SAndroid Build Coastguard Worker                diff_real_func_out,
1546*da0073e9SAndroid Build Coastguard Worker                eps,
1547*da0073e9SAndroid Build Coastguard Worker                rtol,
1548*da0073e9SAndroid Build Coastguard Worker                atol,
1549*da0073e9SAndroid Build Coastguard Worker                check_grad_dtypes,
1550*da0073e9SAndroid Build Coastguard Worker                nondet_tol,
1551*da0073e9SAndroid Build Coastguard Worker                complex_indices=complex_inp_indices,
1552*da0073e9SAndroid Build Coastguard Worker                use_forward_ad=True,
1553*da0073e9SAndroid Build Coastguard Worker            )
1554*da0073e9SAndroid Build Coastguard Worker            if check_undefined_grad:
1555*da0073e9SAndroid Build Coastguard Worker                _test_undefined_forward_mode(imag_fn, imag_func_out, imag_inputs)
1556*da0073e9SAndroid Build Coastguard Worker                _test_undefined_forward_mode(real_fn, real_func_out, real_inputs)
1557*da0073e9SAndroid Build Coastguard Worker        else:
1558*da0073e9SAndroid Build Coastguard Worker            gradcheck_fn(
1559*da0073e9SAndroid Build Coastguard Worker                func,
1560*da0073e9SAndroid Build Coastguard Worker                func_out,
1561*da0073e9SAndroid Build Coastguard Worker                tupled_inputs,
1562*da0073e9SAndroid Build Coastguard Worker                outputs,
1563*da0073e9SAndroid Build Coastguard Worker                eps,
1564*da0073e9SAndroid Build Coastguard Worker                rtol,
1565*da0073e9SAndroid Build Coastguard Worker                atol,
1566*da0073e9SAndroid Build Coastguard Worker                check_grad_dtypes,
1567*da0073e9SAndroid Build Coastguard Worker                nondet_tol,
1568*da0073e9SAndroid Build Coastguard Worker                use_forward_ad=True,
1569*da0073e9SAndroid Build Coastguard Worker            )
1570*da0073e9SAndroid Build Coastguard Worker            if check_undefined_grad:
1571*da0073e9SAndroid Build Coastguard Worker                _test_undefined_forward_mode(func, outputs, tupled_inputs)
1572*da0073e9SAndroid Build Coastguard Worker
1573*da0073e9SAndroid Build Coastguard Worker
1574*da0073e9SAndroid Build Coastguard Workerdef _slow_gradcheck(
1575*da0073e9SAndroid Build Coastguard Worker    func,
1576*da0073e9SAndroid Build Coastguard Worker    func_out,
1577*da0073e9SAndroid Build Coastguard Worker    tupled_inputs,
1578*da0073e9SAndroid Build Coastguard Worker    outputs,
1579*da0073e9SAndroid Build Coastguard Worker    eps,
1580*da0073e9SAndroid Build Coastguard Worker    rtol,
1581*da0073e9SAndroid Build Coastguard Worker    atol,
1582*da0073e9SAndroid Build Coastguard Worker    check_grad_dtypes,
1583*da0073e9SAndroid Build Coastguard Worker    nondet_tol,
1584*da0073e9SAndroid Build Coastguard Worker    *,
1585*da0073e9SAndroid Build Coastguard Worker    use_forward_ad=False,
1586*da0073e9SAndroid Build Coastguard Worker    complex_indices=None,
1587*da0073e9SAndroid Build Coastguard Worker    test_imag=False,
1588*da0073e9SAndroid Build Coastguard Worker    masked=False,
1589*da0073e9SAndroid Build Coastguard Worker):
1590*da0073e9SAndroid Build Coastguard Worker    func_out = _as_tuple(func_out)
1591*da0073e9SAndroid Build Coastguard Worker    if not outputs:
1592*da0073e9SAndroid Build Coastguard Worker        return _check_no_differentiable_outputs(
1593*da0073e9SAndroid Build Coastguard Worker            func, tupled_inputs, func_out, eps=eps, is_forward_ad=use_forward_ad
1594*da0073e9SAndroid Build Coastguard Worker        )
1595*da0073e9SAndroid Build Coastguard Worker    tupled_inputs_numerical = tupled_inputs if masked else _densify(tupled_inputs)
1596*da0073e9SAndroid Build Coastguard Worker
1597*da0073e9SAndroid Build Coastguard Worker    numerical = _transpose(
1598*da0073e9SAndroid Build Coastguard Worker        _get_numerical_jacobian(
1599*da0073e9SAndroid Build Coastguard Worker            func,
1600*da0073e9SAndroid Build Coastguard Worker            tupled_inputs_numerical,
1601*da0073e9SAndroid Build Coastguard Worker            func_out,
1602*da0073e9SAndroid Build Coastguard Worker            eps=eps,
1603*da0073e9SAndroid Build Coastguard Worker            is_forward_ad=use_forward_ad,
1604*da0073e9SAndroid Build Coastguard Worker        )
1605*da0073e9SAndroid Build Coastguard Worker    )
1606*da0073e9SAndroid Build Coastguard Worker    # Note: [numerical vs analytical output length]
1607*da0073e9SAndroid Build Coastguard Worker    # The numerical path returns jacobian quantity for all outputs, even if requires_grad of that
1608*da0073e9SAndroid Build Coastguard Worker    # output is False. This behavior is necessary for _check_no_differentiable_outputs to work.
1609*da0073e9SAndroid Build Coastguard Worker    numerical = [nj for o, nj in zip(func_out, numerical) if o.requires_grad]
1610*da0073e9SAndroid Build Coastguard Worker    if use_forward_ad:
1611*da0073e9SAndroid Build Coastguard Worker        analytical_forward = _get_analytical_jacobian_forward_ad(
1612*da0073e9SAndroid Build Coastguard Worker            func, tupled_inputs, func_out, check_grad_dtypes=check_grad_dtypes
1613*da0073e9SAndroid Build Coastguard Worker        )
1614*da0073e9SAndroid Build Coastguard Worker
1615*da0073e9SAndroid Build Coastguard Worker        for i, n_per_out in enumerate(numerical):
1616*da0073e9SAndroid Build Coastguard Worker            for j, n in enumerate(n_per_out):
1617*da0073e9SAndroid Build Coastguard Worker                a = analytical_forward[j][i]
1618*da0073e9SAndroid Build Coastguard Worker                if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol):
1619*da0073e9SAndroid Build Coastguard Worker                    raise GradcheckError(
1620*da0073e9SAndroid Build Coastguard Worker                        _get_notallclose_msg(
1621*da0073e9SAndroid Build Coastguard Worker                            a, n, i, j, complex_indices, test_imag, is_forward_ad=True
1622*da0073e9SAndroid Build Coastguard Worker                        )
1623*da0073e9SAndroid Build Coastguard Worker                    )
1624*da0073e9SAndroid Build Coastguard Worker    else:
1625*da0073e9SAndroid Build Coastguard Worker        for i, o in enumerate(outputs):
1626*da0073e9SAndroid Build Coastguard Worker            analytical = _check_analytical_jacobian_attributes(
1627*da0073e9SAndroid Build Coastguard Worker                tupled_inputs, o, nondet_tol, check_grad_dtypes
1628*da0073e9SAndroid Build Coastguard Worker            )
1629*da0073e9SAndroid Build Coastguard Worker
1630*da0073e9SAndroid Build Coastguard Worker            for j, (a, n) in enumerate(zip(analytical, numerical[i])):
1631*da0073e9SAndroid Build Coastguard Worker                if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol):
1632*da0073e9SAndroid Build Coastguard Worker                    raise GradcheckError(
1633*da0073e9SAndroid Build Coastguard Worker                        _get_notallclose_msg(a, n, i, j, complex_indices, test_imag)
1634*da0073e9SAndroid Build Coastguard Worker                    )
1635*da0073e9SAndroid Build Coastguard Worker
1636*da0073e9SAndroid Build Coastguard Worker    return True
1637*da0073e9SAndroid Build Coastguard Worker
1638*da0073e9SAndroid Build Coastguard Worker
1639*da0073e9SAndroid Build Coastguard Workerdef _dot_with_type_promotion(u, v):
1640*da0073e9SAndroid Build Coastguard Worker    assert u.dim() == 1 and v.dim() == 1
1641*da0073e9SAndroid Build Coastguard Worker    return (u * v).sum()
1642*da0073e9SAndroid Build Coastguard Worker
1643*da0073e9SAndroid Build Coastguard Worker
1644*da0073e9SAndroid Build Coastguard Workerdef _allclose_with_type_promotion(a, b, rtol, atol):
1645*da0073e9SAndroid Build Coastguard Worker    promoted_type = torch.promote_types(a.dtype, b.dtype)
1646*da0073e9SAndroid Build Coastguard Worker    a = a.to(dtype=promoted_type)
1647*da0073e9SAndroid Build Coastguard Worker    b = b.to(dtype=promoted_type)
1648*da0073e9SAndroid Build Coastguard Worker    return torch.allclose(a, b, rtol, atol)
1649*da0073e9SAndroid Build Coastguard Worker
1650*da0073e9SAndroid Build Coastguard Worker
1651*da0073e9SAndroid Build Coastguard Workerdef _to_real_dtype(dtype):
1652*da0073e9SAndroid Build Coastguard Worker    if dtype == torch.complex128:
1653*da0073e9SAndroid Build Coastguard Worker        return torch.float64
1654*da0073e9SAndroid Build Coastguard Worker    elif dtype == torch.complex64:
1655*da0073e9SAndroid Build Coastguard Worker        return torch.float32
1656*da0073e9SAndroid Build Coastguard Worker    else:
1657*da0073e9SAndroid Build Coastguard Worker        return dtype
1658*da0073e9SAndroid Build Coastguard Worker
1659*da0073e9SAndroid Build Coastguard Worker
1660*da0073e9SAndroid Build Coastguard Workerdef _vec_from_tensor(x, generator, downcast_complex=False):
1661*da0073e9SAndroid Build Coastguard Worker    # Create a random vector with the same number of elements as x and the same
1662*da0073e9SAndroid Build Coastguard Worker    # dtype/device. If x is complex and downcast_complex is False, we create a
1663*da0073e9SAndroid Build Coastguard Worker    # complex tensor with only real component.
1664*da0073e9SAndroid Build Coastguard Worker    if x.layout == torch.sparse_coo:
1665*da0073e9SAndroid Build Coastguard Worker        # For sparse, create a random sparse vec with random values in the same
1666*da0073e9SAndroid Build Coastguard Worker        # indices. Make sure size is set so that it isn't inferred to be smaller.
1667*da0073e9SAndroid Build Coastguard Worker        x_values = x._values()
1668*da0073e9SAndroid Build Coastguard Worker        dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype
1669*da0073e9SAndroid Build Coastguard Worker        values = (
1670*da0073e9SAndroid Build Coastguard Worker            torch.rand(x_values.numel(), generator=generator)
1671*da0073e9SAndroid Build Coastguard Worker            .to(dtype=dtype, device=x.device)
1672*da0073e9SAndroid Build Coastguard Worker            .view(x_values.shape)
1673*da0073e9SAndroid Build Coastguard Worker        )
1674*da0073e9SAndroid Build Coastguard Worker        values /= values.norm()
1675*da0073e9SAndroid Build Coastguard Worker        vec = torch.sparse_coo_tensor(x._indices(), values, x.size(), device=x.device)
1676*da0073e9SAndroid Build Coastguard Worker    elif _is_sparse_compressed_tensor(x):
1677*da0073e9SAndroid Build Coastguard Worker        if x.layout in {torch.sparse_csr, torch.sparse_bsr}:
1678*da0073e9SAndroid Build Coastguard Worker            compressed_indices, plain_indices = x.crow_indices(), x.col_indices()
1679*da0073e9SAndroid Build Coastguard Worker        else:
1680*da0073e9SAndroid Build Coastguard Worker            compressed_indices, plain_indices = x.ccol_indices(), x.row_indices()
1681*da0073e9SAndroid Build Coastguard Worker        x_values = x.values()
1682*da0073e9SAndroid Build Coastguard Worker        dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype
1683*da0073e9SAndroid Build Coastguard Worker        values = (
1684*da0073e9SAndroid Build Coastguard Worker            torch.rand(x_values.numel(), generator=generator)
1685*da0073e9SAndroid Build Coastguard Worker            .to(dtype=dtype, device=x.device)
1686*da0073e9SAndroid Build Coastguard Worker            .view(x_values.shape)
1687*da0073e9SAndroid Build Coastguard Worker        )
1688*da0073e9SAndroid Build Coastguard Worker        values /= values.norm()
1689*da0073e9SAndroid Build Coastguard Worker        vec = torch.sparse_compressed_tensor(
1690*da0073e9SAndroid Build Coastguard Worker            compressed_indices,
1691*da0073e9SAndroid Build Coastguard Worker            plain_indices,
1692*da0073e9SAndroid Build Coastguard Worker            values,
1693*da0073e9SAndroid Build Coastguard Worker            x.size(),
1694*da0073e9SAndroid Build Coastguard Worker            layout=x.layout,
1695*da0073e9SAndroid Build Coastguard Worker            device=x.device,
1696*da0073e9SAndroid Build Coastguard Worker        )
1697*da0073e9SAndroid Build Coastguard Worker    else:
1698*da0073e9SAndroid Build Coastguard Worker        dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype
1699*da0073e9SAndroid Build Coastguard Worker        vec = torch.rand(x.numel(), generator=generator).to(
1700*da0073e9SAndroid Build Coastguard Worker            dtype=dtype, device=x.device
1701*da0073e9SAndroid Build Coastguard Worker        )
1702*da0073e9SAndroid Build Coastguard Worker        vec /= vec.norm()
1703*da0073e9SAndroid Build Coastguard Worker    return vec
1704*da0073e9SAndroid Build Coastguard Worker
1705*da0073e9SAndroid Build Coastguard Worker
1706*da0073e9SAndroid Build Coastguard Workerdef _get_inp_tensors(tupled_inputs):
1707*da0073e9SAndroid Build Coastguard Worker    inp_idx_tup = [
1708*da0073e9SAndroid Build Coastguard Worker        (i, t)
1709*da0073e9SAndroid Build Coastguard Worker        for i, t in enumerate(tupled_inputs)
1710*da0073e9SAndroid Build Coastguard Worker        if is_tensor_like(t) and t.requires_grad
1711*da0073e9SAndroid Build Coastguard Worker    ]
1712*da0073e9SAndroid Build Coastguard Worker    return [tup[0] for tup in inp_idx_tup], [tup[1] for tup in inp_idx_tup]
1713*da0073e9SAndroid Build Coastguard Worker
1714*da0073e9SAndroid Build Coastguard Worker
1715*da0073e9SAndroid Build Coastguard Workerdef _adjusted_atol(atol, u, v):
1716*da0073e9SAndroid Build Coastguard Worker    # In slow gradcheck, we compare A and B element-wise, i.e., for some a, b we
1717*da0073e9SAndroid Build Coastguard Worker    # allow: |a - b| < atol + rtol * b. But since we now compare q1 = v^T A u and
1718*da0073e9SAndroid Build Coastguard Worker    # q2 = v^T B u, we must allow |q1 - q2| < v^T E u + rtol * v^T B u, where E is
1719*da0073e9SAndroid Build Coastguard Worker    # the correctly sized matrix in which each entry is atol.
1720*da0073e9SAndroid Build Coastguard Worker    #
1721*da0073e9SAndroid Build Coastguard Worker    # We see that atol needs to be scaled by v^T M u (where M is an all-ones M x N
1722*da0073e9SAndroid Build Coastguard Worker    # matrix): v^T M u = \sum_{i} \sum_{j} u_i * v_j = (\sum_{i} u_i)(\sum_{i} v_i)
1723*da0073e9SAndroid Build Coastguard Worker    # TODO: properly handle case when u is tuple instead of only taking first element
1724*da0073e9SAndroid Build Coastguard Worker    u = u[0] if isinstance(u, tuple) else u
1725*da0073e9SAndroid Build Coastguard Worker    sum_u = u.sum()
1726*da0073e9SAndroid Build Coastguard Worker    sum_v = 1.0 if v is None else v.sum()
1727*da0073e9SAndroid Build Coastguard Worker    return atol * float(sum_u) * float(sum_v)
1728*da0073e9SAndroid Build Coastguard Worker
1729*da0073e9SAndroid Build Coastguard Worker
1730*da0073e9SAndroid Build Coastguard WorkerFAST_FAIL_SLOW_OK_MSG = """
1731*da0073e9SAndroid Build Coastguard WorkerFast gradcheck failed but element-wise differences are small. This means that the
1732*da0073e9SAndroid Build Coastguard Workertest might've passed in slow_mode!
1733*da0073e9SAndroid Build Coastguard Worker
1734*da0073e9SAndroid Build Coastguard WorkerIf you are adding a new operator, please file an issue and then use one of the
1735*da0073e9SAndroid Build Coastguard Workerworkarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck:
1736*da0073e9SAndroid Build Coastguard Worker
1737*da0073e9SAndroid Build Coastguard WorkerIf the test
1738*da0073e9SAndroid Build Coastguard Worker- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
1739*da0073e9SAndroid Build Coastguard Worker  with `fast_mode=False` as a keyword argument.
1740*da0073e9SAndroid Build Coastguard Worker- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
1741*da0073e9SAndroid Build Coastguard Worker  to have `gradcheck_fast_mode=False`
1742*da0073e9SAndroid Build Coastguard Worker- is a Module test (e.g., in common_nn.py), then modify the corresponding
1743*da0073e9SAndroid Build Coastguard Worker  module_test entry to have `gradcheck_fast_mode=False`
1744*da0073e9SAndroid Build Coastguard Worker""".strip()
1745*da0073e9SAndroid Build Coastguard Worker
1746*da0073e9SAndroid Build Coastguard Worker
1747*da0073e9SAndroid Build Coastguard Workerdef _run_slow_mode_and_get_error(
1748*da0073e9SAndroid Build Coastguard Worker    func, tupled_inputs, outputs, input_idx, output_idx, rtol, atol, eps, is_forward_ad
1749*da0073e9SAndroid Build Coastguard Worker):
1750*da0073e9SAndroid Build Coastguard Worker    # Compute jacobians in slow mode for better error message
1751*da0073e9SAndroid Build Coastguard Worker    slow_numerical = _get_numerical_jacobian(
1752*da0073e9SAndroid Build Coastguard Worker        func, tupled_inputs, outputs, eps=eps, is_forward_ad=is_forward_ad
1753*da0073e9SAndroid Build Coastguard Worker    )[input_idx][output_idx]
1754*da0073e9SAndroid Build Coastguard Worker    if is_forward_ad:
1755*da0073e9SAndroid Build Coastguard Worker
1756*da0073e9SAndroid Build Coastguard Worker        def new_fn(inp):
1757*da0073e9SAndroid Build Coastguard Worker            new_inputs = list(tupled_inputs)
1758*da0073e9SAndroid Build Coastguard Worker            new_inputs[input_idx] = inp
1759*da0073e9SAndroid Build Coastguard Worker            return _as_tuple(func(*new_inputs))[output_idx]
1760*da0073e9SAndroid Build Coastguard Worker
1761*da0073e9SAndroid Build Coastguard Worker        slow_analytical = _get_analytical_jacobian_forward_ad(
1762*da0073e9SAndroid Build Coastguard Worker            new_fn, (tupled_inputs[input_idx],), (outputs[output_idx],)
1763*da0073e9SAndroid Build Coastguard Worker        )[0][0]
1764*da0073e9SAndroid Build Coastguard Worker    else:
1765*da0073e9SAndroid Build Coastguard Worker        slow_analytical = _get_analytical_jacobian(
1766*da0073e9SAndroid Build Coastguard Worker            tupled_inputs, outputs, input_idx, output_idx
1767*da0073e9SAndroid Build Coastguard Worker        )
1768*da0073e9SAndroid Build Coastguard Worker
1769*da0073e9SAndroid Build Coastguard Worker    # Assume jacobians are non-empty and have the same shape
1770*da0073e9SAndroid Build Coastguard Worker    slow_max_diff = (slow_numerical - slow_analytical).abs().max()
1771*da0073e9SAndroid Build Coastguard Worker
1772*da0073e9SAndroid Build Coastguard Worker    slow_allclose = torch.allclose(slow_analytical, slow_numerical, rtol, atol)
1773*da0073e9SAndroid Build Coastguard Worker    msg = (
1774*da0073e9SAndroid Build Coastguard Worker        "\nThe above quantities relating the numerical and analytical jacobians are computed \n"
1775*da0073e9SAndroid Build Coastguard Worker        "in fast mode. See: https://github.com/pytorch/pytorch/issues/53876 for more background \n"
1776*da0073e9SAndroid Build Coastguard Worker        "about fast mode. Below, we recompute numerical and analytical jacobians in slow mode:\n\n"
1777*da0073e9SAndroid Build Coastguard Worker        f"Numerical:\n {slow_numerical}\n"
1778*da0073e9SAndroid Build Coastguard Worker        f"Analytical:\n{slow_analytical}\n\n"
1779*da0073e9SAndroid Build Coastguard Worker        f"The max per-element difference (slow mode) is: {slow_max_diff}.\n"
1780*da0073e9SAndroid Build Coastguard Worker    )
1781*da0073e9SAndroid Build Coastguard Worker    if slow_allclose:
1782*da0073e9SAndroid Build Coastguard Worker        # Slow gradcheck would've passed!
1783*da0073e9SAndroid Build Coastguard Worker        msg += FAST_FAIL_SLOW_OK_MSG
1784*da0073e9SAndroid Build Coastguard Worker    return msg
1785*da0073e9SAndroid Build Coastguard Worker
1786*da0073e9SAndroid Build Coastguard Worker
1787*da0073e9SAndroid Build Coastguard Workerdef _to_flat_dense_if_sparse(tensor):
1788*da0073e9SAndroid Build Coastguard Worker    if _is_sparse_any_tensor(tensor):
1789*da0073e9SAndroid Build Coastguard Worker        return tensor.to_dense().reshape(-1)
1790*da0073e9SAndroid Build Coastguard Worker    else:
1791*da0073e9SAndroid Build Coastguard Worker        return tensor
1792*da0073e9SAndroid Build Coastguard Worker
1793*da0073e9SAndroid Build Coastguard Worker
1794*da0073e9SAndroid Build Coastguard Workerdef _make_vectors(inp_tensors, outputs, *, use_forward_ad):
1795*da0073e9SAndroid Build Coastguard Worker    # Use our own generator to avoid messing with the user's RNG state
1796*da0073e9SAndroid Build Coastguard Worker    g_cpu = torch.Generator()
1797*da0073e9SAndroid Build Coastguard Worker
1798*da0073e9SAndroid Build Coastguard Worker    def _vec_from_tensor_cpu(*args):
1799*da0073e9SAndroid Build Coastguard Worker        # Default allocate all tensors on CPU, so they are on the same device as the generator
1800*da0073e9SAndroid Build Coastguard Worker        # even if the user specified a default device
1801*da0073e9SAndroid Build Coastguard Worker        with torch.device("cpu"):
1802*da0073e9SAndroid Build Coastguard Worker            return _vec_from_tensor(*args)
1803*da0073e9SAndroid Build Coastguard Worker
1804*da0073e9SAndroid Build Coastguard Worker    all_u = []
1805*da0073e9SAndroid Build Coastguard Worker    all_u_dense = []
1806*da0073e9SAndroid Build Coastguard Worker    for inp in inp_tensors:
1807*da0073e9SAndroid Build Coastguard Worker        ur = _vec_from_tensor_cpu(inp, g_cpu, True)
1808*da0073e9SAndroid Build Coastguard Worker        ur_dense = _to_flat_dense_if_sparse(ur)
1809*da0073e9SAndroid Build Coastguard Worker        if inp.is_complex():
1810*da0073e9SAndroid Build Coastguard Worker            ui = _vec_from_tensor_cpu(inp, g_cpu, True)
1811*da0073e9SAndroid Build Coastguard Worker            all_u.append((ur, ui))
1812*da0073e9SAndroid Build Coastguard Worker            ui_dense = _to_flat_dense_if_sparse(ui)
1813*da0073e9SAndroid Build Coastguard Worker            all_u_dense.append((ur_dense, ui_dense))
1814*da0073e9SAndroid Build Coastguard Worker        else:
1815*da0073e9SAndroid Build Coastguard Worker            all_u.append(ur)
1816*da0073e9SAndroid Build Coastguard Worker            all_u_dense.append(ur_dense)
1817*da0073e9SAndroid Build Coastguard Worker    all_v = (
1818*da0073e9SAndroid Build Coastguard Worker        None
1819*da0073e9SAndroid Build Coastguard Worker        if use_forward_ad
1820*da0073e9SAndroid Build Coastguard Worker        else [_vec_from_tensor_cpu(out, g_cpu) for out in outputs]
1821*da0073e9SAndroid Build Coastguard Worker    )
1822*da0073e9SAndroid Build Coastguard Worker    return all_v, all_u, all_u_dense
1823*da0073e9SAndroid Build Coastguard Worker
1824*da0073e9SAndroid Build Coastguard Worker
1825*da0073e9SAndroid Build Coastguard Workerdef _check_analytical_numerical_equal(
1826*da0073e9SAndroid Build Coastguard Worker    all_analytical,
1827*da0073e9SAndroid Build Coastguard Worker    all_numerical,
1828*da0073e9SAndroid Build Coastguard Worker    complex_indices,
1829*da0073e9SAndroid Build Coastguard Worker    tupled_inputs,
1830*da0073e9SAndroid Build Coastguard Worker    outputs,
1831*da0073e9SAndroid Build Coastguard Worker    func,
1832*da0073e9SAndroid Build Coastguard Worker    all_v,
1833*da0073e9SAndroid Build Coastguard Worker    all_u,
1834*da0073e9SAndroid Build Coastguard Worker    rtol,
1835*da0073e9SAndroid Build Coastguard Worker    atol,
1836*da0073e9SAndroid Build Coastguard Worker    eps,
1837*da0073e9SAndroid Build Coastguard Worker    test_imag,
1838*da0073e9SAndroid Build Coastguard Worker    *,
1839*da0073e9SAndroid Build Coastguard Worker    is_forward_ad=False,
1840*da0073e9SAndroid Build Coastguard Worker):
1841*da0073e9SAndroid Build Coastguard Worker    for i, all_numerical_for_input_i in enumerate(all_numerical):
1842*da0073e9SAndroid Build Coastguard Worker        for j, n in enumerate(all_numerical_for_input_i):
1843*da0073e9SAndroid Build Coastguard Worker            # Forward AD generates the transpose of what this function expects
1844*da0073e9SAndroid Build Coastguard Worker            if is_forward_ad:
1845*da0073e9SAndroid Build Coastguard Worker                a = all_analytical[i][j]
1846*da0073e9SAndroid Build Coastguard Worker            else:
1847*da0073e9SAndroid Build Coastguard Worker                a = all_analytical[j][i]
1848*da0073e9SAndroid Build Coastguard Worker            n = n.to(device=a.device)
1849*da0073e9SAndroid Build Coastguard Worker            updated_atol = _adjusted_atol(atol, all_u[i], all_v[j] if all_v else None)
1850*da0073e9SAndroid Build Coastguard Worker            if not _allclose_with_type_promotion(a, n.to(a.device), rtol, updated_atol):
1851*da0073e9SAndroid Build Coastguard Worker                jacobians_str = _run_slow_mode_and_get_error(
1852*da0073e9SAndroid Build Coastguard Worker                    func, tupled_inputs, outputs, i, j, rtol, atol, eps, is_forward_ad
1853*da0073e9SAndroid Build Coastguard Worker                )
1854*da0073e9SAndroid Build Coastguard Worker                raise GradcheckError(
1855*da0073e9SAndroid Build Coastguard Worker                    _get_notallclose_msg(
1856*da0073e9SAndroid Build Coastguard Worker                        a, n, j, i, complex_indices, test_imag, is_forward_ad
1857*da0073e9SAndroid Build Coastguard Worker                    )
1858*da0073e9SAndroid Build Coastguard Worker                    + jacobians_str
1859*da0073e9SAndroid Build Coastguard Worker                )
1860*da0073e9SAndroid Build Coastguard Worker
1861*da0073e9SAndroid Build Coastguard Worker
1862*da0073e9SAndroid Build Coastguard Workerdef _fast_gradcheck(
1863*da0073e9SAndroid Build Coastguard Worker    func,
1864*da0073e9SAndroid Build Coastguard Worker    func_out,
1865*da0073e9SAndroid Build Coastguard Worker    inputs,
1866*da0073e9SAndroid Build Coastguard Worker    outputs,
1867*da0073e9SAndroid Build Coastguard Worker    eps,
1868*da0073e9SAndroid Build Coastguard Worker    rtol,
1869*da0073e9SAndroid Build Coastguard Worker    atol,
1870*da0073e9SAndroid Build Coastguard Worker    check_grad_dtypes,
1871*da0073e9SAndroid Build Coastguard Worker    nondet_tol,
1872*da0073e9SAndroid Build Coastguard Worker    *,
1873*da0073e9SAndroid Build Coastguard Worker    use_forward_ad=False,
1874*da0073e9SAndroid Build Coastguard Worker    complex_indices=None,
1875*da0073e9SAndroid Build Coastguard Worker    test_imag=False,
1876*da0073e9SAndroid Build Coastguard Worker    masked=False,
1877*da0073e9SAndroid Build Coastguard Worker):
1878*da0073e9SAndroid Build Coastguard Worker    # See https://github.com/pytorch/pytorch/issues/53876 for details
1879*da0073e9SAndroid Build Coastguard Worker    inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs)
1880*da0073e9SAndroid Build Coastguard Worker    # Backward mode computes v^T * J (VJP)
1881*da0073e9SAndroid Build Coastguard Worker    # Since we computed J * u (JVP) through finite difference method, we perform an equality check
1882*da0073e9SAndroid Build Coastguard Worker    # between VJP * u, v * JVP
1883*da0073e9SAndroid Build Coastguard Worker    # ----
1884*da0073e9SAndroid Build Coastguard Worker    # Forward mode computes J * u (JVP)
1885*da0073e9SAndroid Build Coastguard Worker    # Since we already compute JVP through finite difference method,
1886*da0073e9SAndroid Build Coastguard Worker    # we don't need v for correctness check here as asserted below
1887*da0073e9SAndroid Build Coastguard Worker    all_v, all_u, all_u_dense = _make_vectors(
1888*da0073e9SAndroid Build Coastguard Worker        inp_tensors, outputs, use_forward_ad=use_forward_ad
1889*da0073e9SAndroid Build Coastguard Worker    )
1890*da0073e9SAndroid Build Coastguard Worker
1891*da0073e9SAndroid Build Coastguard Worker    inputs_numerical, all_u_numerical, all_v_numerical = (
1892*da0073e9SAndroid Build Coastguard Worker        (inputs, all_u, all_v) if masked else _densify((inputs, all_u, all_v))
1893*da0073e9SAndroid Build Coastguard Worker    )
1894*da0073e9SAndroid Build Coastguard Worker
1895*da0073e9SAndroid Build Coastguard Worker    numerical_vJu = _get_numerical_vJu(
1896*da0073e9SAndroid Build Coastguard Worker        func,
1897*da0073e9SAndroid Build Coastguard Worker        inputs_numerical,
1898*da0073e9SAndroid Build Coastguard Worker        inp_tensors_idx,
1899*da0073e9SAndroid Build Coastguard Worker        func_out,
1900*da0073e9SAndroid Build Coastguard Worker        all_u_numerical,
1901*da0073e9SAndroid Build Coastguard Worker        all_v_numerical,
1902*da0073e9SAndroid Build Coastguard Worker        eps,
1903*da0073e9SAndroid Build Coastguard Worker        is_forward_ad=use_forward_ad,
1904*da0073e9SAndroid Build Coastguard Worker    )
1905*da0073e9SAndroid Build Coastguard Worker    # TODO: replicate https://github.com/pytorch/pytorch/pull/77743 for fast gradcheck as well
1906*da0073e9SAndroid Build Coastguard Worker    if use_forward_ad:
1907*da0073e9SAndroid Build Coastguard Worker        assert all_v is None
1908*da0073e9SAndroid Build Coastguard Worker        analytical_vJu = _get_analytical_jacobian_forward_ad(
1909*da0073e9SAndroid Build Coastguard Worker            func,
1910*da0073e9SAndroid Build Coastguard Worker            inputs,
1911*da0073e9SAndroid Build Coastguard Worker            _as_tuple(func_out),
1912*da0073e9SAndroid Build Coastguard Worker            all_u=all_u,
1913*da0073e9SAndroid Build Coastguard Worker            check_grad_dtypes=check_grad_dtypes,
1914*da0073e9SAndroid Build Coastguard Worker        )
1915*da0073e9SAndroid Build Coastguard Worker    else:
1916*da0073e9SAndroid Build Coastguard Worker        if not outputs:
1917*da0073e9SAndroid Build Coastguard Worker            _check_no_differentiable_outputs_fast(
1918*da0073e9SAndroid Build Coastguard Worker                func, func_out, inputs, inp_tensors_idx, all_u, eps, nondet_tol
1919*da0073e9SAndroid Build Coastguard Worker            )
1920*da0073e9SAndroid Build Coastguard Worker
1921*da0073e9SAndroid Build Coastguard Worker        analytical_vJu = _get_analytical_vJu_backward_mode(
1922*da0073e9SAndroid Build Coastguard Worker            inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u_dense
1923*da0073e9SAndroid Build Coastguard Worker        )
1924*da0073e9SAndroid Build Coastguard Worker
1925*da0073e9SAndroid Build Coastguard Worker    _check_analytical_numerical_equal(
1926*da0073e9SAndroid Build Coastguard Worker        analytical_vJu,
1927*da0073e9SAndroid Build Coastguard Worker        numerical_vJu,
1928*da0073e9SAndroid Build Coastguard Worker        complex_indices,
1929*da0073e9SAndroid Build Coastguard Worker        inputs,
1930*da0073e9SAndroid Build Coastguard Worker        outputs,
1931*da0073e9SAndroid Build Coastguard Worker        func,
1932*da0073e9SAndroid Build Coastguard Worker        all_v,
1933*da0073e9SAndroid Build Coastguard Worker        all_u,
1934*da0073e9SAndroid Build Coastguard Worker        rtol,
1935*da0073e9SAndroid Build Coastguard Worker        atol,
1936*da0073e9SAndroid Build Coastguard Worker        eps,
1937*da0073e9SAndroid Build Coastguard Worker        test_imag,
1938*da0073e9SAndroid Build Coastguard Worker        is_forward_ad=use_forward_ad,
1939*da0073e9SAndroid Build Coastguard Worker    )
1940*da0073e9SAndroid Build Coastguard Worker
1941*da0073e9SAndroid Build Coastguard Worker    return True
1942*da0073e9SAndroid Build Coastguard Worker
1943*da0073e9SAndroid Build Coastguard Worker
1944*da0073e9SAndroid Build Coastguard Worker# Note [VarArg of Tensors]
1945*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~
1946*da0073e9SAndroid Build Coastguard Worker# 'func' accepts a vararg of tensors, which isn't expressable in the type system at the moment.
1947*da0073e9SAndroid Build Coastguard Worker# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
1948*da0073e9SAndroid Build Coastguard Worker# the '...' first argument of Callable can be replaced with VarArg(Tensor).
1949*da0073e9SAndroid Build Coastguard Worker# For now, we permit any input.
1950*da0073e9SAndroid Build Coastguard Workerdef gradcheck(
1951*da0073e9SAndroid Build Coastguard Worker    func: Callable[..., Union[_TensorOrTensors]],  # See Note [VarArg of Tensors]
1952*da0073e9SAndroid Build Coastguard Worker    inputs: _TensorOrTensors,
1953*da0073e9SAndroid Build Coastguard Worker    *,
1954*da0073e9SAndroid Build Coastguard Worker    eps: float = 1e-6,
1955*da0073e9SAndroid Build Coastguard Worker    atol: float = 1e-5,
1956*da0073e9SAndroid Build Coastguard Worker    rtol: float = 1e-3,
1957*da0073e9SAndroid Build Coastguard Worker    raise_exception: bool = True,
1958*da0073e9SAndroid Build Coastguard Worker    nondet_tol: float = 0.0,
1959*da0073e9SAndroid Build Coastguard Worker    check_undefined_grad: bool = True,
1960*da0073e9SAndroid Build Coastguard Worker    check_grad_dtypes: bool = False,
1961*da0073e9SAndroid Build Coastguard Worker    check_batched_grad: bool = False,
1962*da0073e9SAndroid Build Coastguard Worker    check_batched_forward_grad: bool = False,
1963*da0073e9SAndroid Build Coastguard Worker    check_forward_ad: bool = False,
1964*da0073e9SAndroid Build Coastguard Worker    check_backward_ad: bool = True,
1965*da0073e9SAndroid Build Coastguard Worker    fast_mode: bool = False,
1966*da0073e9SAndroid Build Coastguard Worker    masked: Optional[bool] = None,
1967*da0073e9SAndroid Build Coastguard Worker) -> bool:  # noqa: D400,D205
1968*da0073e9SAndroid Build Coastguard Worker    r"""Check gradients computed via small finite differences against analytical
1969*da0073e9SAndroid Build Coastguard Worker    gradients wrt tensors in :attr:`inputs` that are of floating point or complex type
1970*da0073e9SAndroid Build Coastguard Worker    and with ``requires_grad=True``.
1971*da0073e9SAndroid Build Coastguard Worker
1972*da0073e9SAndroid Build Coastguard Worker    The check between numerical and analytical gradients uses :func:`~torch.allclose`.
1973*da0073e9SAndroid Build Coastguard Worker
1974*da0073e9SAndroid Build Coastguard Worker    For most of the complex functions we consider for optimization purposes, no notion of
1975*da0073e9SAndroid Build Coastguard Worker    Jacobian exists. Instead, gradcheck verifies if the numerical and analytical values of
1976*da0073e9SAndroid Build Coastguard Worker    the Wirtinger and Conjugate Wirtinger derivatives are consistent. Because the gradient
1977*da0073e9SAndroid Build Coastguard Worker    computation is done under the assumption that the overall function has a real-valued
1978*da0073e9SAndroid Build Coastguard Worker    output, we treat functions with complex output in a special way. For these functions,
1979*da0073e9SAndroid Build Coastguard Worker    gradcheck is applied to two real-valued functions corresponding to taking the real
1980*da0073e9SAndroid Build Coastguard Worker    components of the complex outputs for the first, and taking the imaginary components
1981*da0073e9SAndroid Build Coastguard Worker    of the complex outputs for the second. For more details, check out
1982*da0073e9SAndroid Build Coastguard Worker    :ref:`complex_autograd-doc`.
1983*da0073e9SAndroid Build Coastguard Worker
1984*da0073e9SAndroid Build Coastguard Worker    .. note::
1985*da0073e9SAndroid Build Coastguard Worker        The default values are designed for :attr:`input` of double precision.
1986*da0073e9SAndroid Build Coastguard Worker        This check will likely fail if :attr:`input` is of less precision, e.g.,
1987*da0073e9SAndroid Build Coastguard Worker        ``FloatTensor``.
1988*da0073e9SAndroid Build Coastguard Worker
1989*da0073e9SAndroid Build Coastguard Worker    .. note::
1990*da0073e9SAndroid Build Coastguard Worker        Gradcheck may fail when evaluated on non-differentiable points
1991*da0073e9SAndroid Build Coastguard Worker        because the numerically computed gradients via finite differencing may differ
1992*da0073e9SAndroid Build Coastguard Worker        those computed analytically (not necessarily because either is incorrect).
1993*da0073e9SAndroid Build Coastguard Worker        For more context, see :ref:`non-differentiable-func-grad`.
1994*da0073e9SAndroid Build Coastguard Worker
1995*da0073e9SAndroid Build Coastguard Worker    .. warning::
1996*da0073e9SAndroid Build Coastguard Worker       If any checked tensor in :attr:`input` has overlapping memory, i.e.,
1997*da0073e9SAndroid Build Coastguard Worker       different indices pointing to the same memory address (e.g., from
1998*da0073e9SAndroid Build Coastguard Worker       :func:`torch.expand`), this check will likely fail because the numerical
1999*da0073e9SAndroid Build Coastguard Worker       gradients computed by point perturbation at such indices will change
2000*da0073e9SAndroid Build Coastguard Worker       values at all other indices that share the same memory address.
2001*da0073e9SAndroid Build Coastguard Worker
2002*da0073e9SAndroid Build Coastguard Worker    Args:
2003*da0073e9SAndroid Build Coastguard Worker        func (function): a Python function that takes Tensor inputs and returns
2004*da0073e9SAndroid Build Coastguard Worker            a Tensor or a tuple of Tensors
2005*da0073e9SAndroid Build Coastguard Worker        inputs (tuple of Tensor or Tensor): inputs to the function
2006*da0073e9SAndroid Build Coastguard Worker        eps (float, optional): perturbation for finite differences
2007*da0073e9SAndroid Build Coastguard Worker        atol (float, optional): absolute tolerance
2008*da0073e9SAndroid Build Coastguard Worker        rtol (float, optional): relative tolerance
2009*da0073e9SAndroid Build Coastguard Worker        raise_exception (bool, optional): indicating whether to raise an exception if
2010*da0073e9SAndroid Build Coastguard Worker            the check fails. The exception gives more information about the
2011*da0073e9SAndroid Build Coastguard Worker            exact nature of the failure. This is helpful when debugging gradchecks.
2012*da0073e9SAndroid Build Coastguard Worker        nondet_tol (float, optional): tolerance for non-determinism. When running
2013*da0073e9SAndroid Build Coastguard Worker            identical inputs through the differentiation, the results must either match
2014*da0073e9SAndroid Build Coastguard Worker            exactly (default, 0.0) or be within this tolerance.
2015*da0073e9SAndroid Build Coastguard Worker        check_undefined_grad (bool, optional): if ``True``, check if undefined output grads
2016*da0073e9SAndroid Build Coastguard Worker            are supported and treated as zeros, for ``Tensor`` outputs.
2017*da0073e9SAndroid Build Coastguard Worker        check_batched_grad (bool, optional): if ``True``, check if we can compute
2018*da0073e9SAndroid Build Coastguard Worker            batched gradients using prototype vmap support. Defaults to False.
2019*da0073e9SAndroid Build Coastguard Worker        check_batched_forward_grad (bool, optional): if ``True``, checks if we can compute
2020*da0073e9SAndroid Build Coastguard Worker            batched forward gradients using forward ad and prototype vmap support. Defaults to ``False``.
2021*da0073e9SAndroid Build Coastguard Worker        check_forward_ad (bool, optional): if ``True``, check that the gradients computed with forward
2022*da0073e9SAndroid Build Coastguard Worker            mode AD match the numerical ones. Defaults to ``False``.
2023*da0073e9SAndroid Build Coastguard Worker        check_backward_ad (bool, optional): if ``False``, do not perform any checks that rely on
2024*da0073e9SAndroid Build Coastguard Worker            backward mode AD to be implemented. Defaults to ``True``.
2025*da0073e9SAndroid Build Coastguard Worker        fast_mode (bool, optional): Fast mode for gradcheck and gradgradcheck is currently only
2026*da0073e9SAndroid Build Coastguard Worker            implemented for R to R functions. If none of the inputs and outputs are complex
2027*da0073e9SAndroid Build Coastguard Worker            a faster implementation of gradcheck that no longer computes the entire jacobian
2028*da0073e9SAndroid Build Coastguard Worker            is run; otherwise, we fall back to the slow implementation.
2029*da0073e9SAndroid Build Coastguard Worker        masked (bool, optional): if ``True``, the gradients of unspecified elements of
2030*da0073e9SAndroid Build Coastguard Worker            sparse tensors are ignored. Defaults to ``False``.
2031*da0073e9SAndroid Build Coastguard Worker    Returns:
2032*da0073e9SAndroid Build Coastguard Worker        ``True`` if all differences satisfy allclose condition
2033*da0073e9SAndroid Build Coastguard Worker
2034*da0073e9SAndroid Build Coastguard Worker    """
2035*da0073e9SAndroid Build Coastguard Worker    assert (
2036*da0073e9SAndroid Build Coastguard Worker        check_forward_ad or check_backward_ad
2037*da0073e9SAndroid Build Coastguard Worker    ), "Expected at least one of check_forward_ad or check_backward_ad to be True"
2038*da0073e9SAndroid Build Coastguard Worker    assert not (
2039*da0073e9SAndroid Build Coastguard Worker        check_batched_grad and not check_backward_ad
2040*da0073e9SAndroid Build Coastguard Worker    ), "Setting check_batched_grad=True requires check_backward_ad to be True"
2041*da0073e9SAndroid Build Coastguard Worker    assert not (
2042*da0073e9SAndroid Build Coastguard Worker        check_batched_forward_grad and not check_forward_ad
2043*da0073e9SAndroid Build Coastguard Worker    ), "Setting check_batched_forward_grad=True requires check_forward_ad to be True"
2044*da0073e9SAndroid Build Coastguard Worker    args = locals().copy()
2045*da0073e9SAndroid Build Coastguard Worker    args.pop("raise_exception")
2046*da0073e9SAndroid Build Coastguard Worker    if not raise_exception:
2047*da0073e9SAndroid Build Coastguard Worker        try:
2048*da0073e9SAndroid Build Coastguard Worker            return _gradcheck_helper(**args)
2049*da0073e9SAndroid Build Coastguard Worker        except GradcheckError as e:
2050*da0073e9SAndroid Build Coastguard Worker            return False
2051*da0073e9SAndroid Build Coastguard Worker    else:
2052*da0073e9SAndroid Build Coastguard Worker        return _gradcheck_helper(**args)
2053*da0073e9SAndroid Build Coastguard Worker
2054*da0073e9SAndroid Build Coastguard Worker
2055*da0073e9SAndroid Build Coastguard Workerdef _gradcheck_helper(
2056*da0073e9SAndroid Build Coastguard Worker    func,
2057*da0073e9SAndroid Build Coastguard Worker    inputs,
2058*da0073e9SAndroid Build Coastguard Worker    eps,
2059*da0073e9SAndroid Build Coastguard Worker    atol,
2060*da0073e9SAndroid Build Coastguard Worker    rtol,
2061*da0073e9SAndroid Build Coastguard Worker    nondet_tol,
2062*da0073e9SAndroid Build Coastguard Worker    check_undefined_grad,
2063*da0073e9SAndroid Build Coastguard Worker    check_grad_dtypes,
2064*da0073e9SAndroid Build Coastguard Worker    check_batched_grad,
2065*da0073e9SAndroid Build Coastguard Worker    check_batched_forward_grad,
2066*da0073e9SAndroid Build Coastguard Worker    check_forward_ad,
2067*da0073e9SAndroid Build Coastguard Worker    check_backward_ad,
2068*da0073e9SAndroid Build Coastguard Worker    fast_mode,
2069*da0073e9SAndroid Build Coastguard Worker    masked,
2070*da0073e9SAndroid Build Coastguard Worker):
2071*da0073e9SAndroid Build Coastguard Worker    tupled_inputs = _as_tuple(inputs)
2072*da0073e9SAndroid Build Coastguard Worker    _check_inputs(tupled_inputs)
2073*da0073e9SAndroid Build Coastguard Worker
2074*da0073e9SAndroid Build Coastguard Worker    func_out = func(*tupled_inputs)
2075*da0073e9SAndroid Build Coastguard Worker    outputs = _differentiable_outputs(func_out)
2076*da0073e9SAndroid Build Coastguard Worker    _check_outputs(outputs)
2077*da0073e9SAndroid Build Coastguard Worker
2078*da0073e9SAndroid Build Coastguard Worker    gradcheck_fn = functools.partial(
2079*da0073e9SAndroid Build Coastguard Worker        _fast_gradcheck if fast_mode else _slow_gradcheck, masked=masked
2080*da0073e9SAndroid Build Coastguard Worker    )
2081*da0073e9SAndroid Build Coastguard Worker    _gradcheck_real_imag(
2082*da0073e9SAndroid Build Coastguard Worker        gradcheck_fn,
2083*da0073e9SAndroid Build Coastguard Worker        func,
2084*da0073e9SAndroid Build Coastguard Worker        func_out,
2085*da0073e9SAndroid Build Coastguard Worker        tupled_inputs,
2086*da0073e9SAndroid Build Coastguard Worker        outputs,
2087*da0073e9SAndroid Build Coastguard Worker        eps,
2088*da0073e9SAndroid Build Coastguard Worker        rtol,
2089*da0073e9SAndroid Build Coastguard Worker        atol,
2090*da0073e9SAndroid Build Coastguard Worker        check_grad_dtypes,
2091*da0073e9SAndroid Build Coastguard Worker        check_forward_ad=check_forward_ad,
2092*da0073e9SAndroid Build Coastguard Worker        check_backward_ad=check_backward_ad,
2093*da0073e9SAndroid Build Coastguard Worker        nondet_tol=nondet_tol,
2094*da0073e9SAndroid Build Coastguard Worker        check_undefined_grad=check_undefined_grad,
2095*da0073e9SAndroid Build Coastguard Worker    )
2096*da0073e9SAndroid Build Coastguard Worker
2097*da0073e9SAndroid Build Coastguard Worker    if check_batched_forward_grad:
2098*da0073e9SAndroid Build Coastguard Worker        _test_batched_grad_forward_ad(func, tupled_inputs)
2099*da0073e9SAndroid Build Coastguard Worker
2100*da0073e9SAndroid Build Coastguard Worker    # Short circuit because remaining tests rely on backward AD to be implemented
2101*da0073e9SAndroid Build Coastguard Worker    if not check_backward_ad:
2102*da0073e9SAndroid Build Coastguard Worker        return True
2103*da0073e9SAndroid Build Coastguard Worker
2104*da0073e9SAndroid Build Coastguard Worker    for i, o in enumerate(outputs):
2105*da0073e9SAndroid Build Coastguard Worker        if check_batched_grad:
2106*da0073e9SAndroid Build Coastguard Worker            _test_batched_grad(tupled_inputs, o, i)
2107*da0073e9SAndroid Build Coastguard Worker
2108*da0073e9SAndroid Build Coastguard Worker    _test_backward_mul_by_grad_output(outputs, tupled_inputs, masked)
2109*da0073e9SAndroid Build Coastguard Worker
2110*da0073e9SAndroid Build Coastguard Worker    if check_undefined_grad and check_backward_ad:
2111*da0073e9SAndroid Build Coastguard Worker        _test_undefined_backward_mode(func, outputs, tupled_inputs)
2112*da0073e9SAndroid Build Coastguard Worker    return True
2113*da0073e9SAndroid Build Coastguard Worker
2114*da0073e9SAndroid Build Coastguard Worker
2115*da0073e9SAndroid Build Coastguard Workerdef gradgradcheck(
2116*da0073e9SAndroid Build Coastguard Worker    func: Callable[..., _TensorOrTensors],  # See Note [VarArg of Tensors]
2117*da0073e9SAndroid Build Coastguard Worker    inputs: _TensorOrTensors,
2118*da0073e9SAndroid Build Coastguard Worker    grad_outputs: Optional[_TensorOrTensors] = None,
2119*da0073e9SAndroid Build Coastguard Worker    *,
2120*da0073e9SAndroid Build Coastguard Worker    eps: float = 1e-6,
2121*da0073e9SAndroid Build Coastguard Worker    atol: float = 1e-5,
2122*da0073e9SAndroid Build Coastguard Worker    rtol: float = 1e-3,
2123*da0073e9SAndroid Build Coastguard Worker    gen_non_contig_grad_outputs: bool = False,
2124*da0073e9SAndroid Build Coastguard Worker    raise_exception: bool = True,
2125*da0073e9SAndroid Build Coastguard Worker    nondet_tol: float = 0.0,
2126*da0073e9SAndroid Build Coastguard Worker    check_undefined_grad: bool = True,
2127*da0073e9SAndroid Build Coastguard Worker    check_grad_dtypes: bool = False,
2128*da0073e9SAndroid Build Coastguard Worker    check_batched_grad: bool = False,
2129*da0073e9SAndroid Build Coastguard Worker    check_fwd_over_rev: bool = False,
2130*da0073e9SAndroid Build Coastguard Worker    check_rev_over_rev: bool = True,
2131*da0073e9SAndroid Build Coastguard Worker    fast_mode: bool = False,
2132*da0073e9SAndroid Build Coastguard Worker    masked: bool = False,
2133*da0073e9SAndroid Build Coastguard Worker) -> bool:  # noqa: D400,D205
2134*da0073e9SAndroid Build Coastguard Worker    r"""Check gradients of gradients computed via small finite differences
2135*da0073e9SAndroid Build Coastguard Worker    against analytical gradients wrt tensors in :attr:`inputs` and
2136*da0073e9SAndroid Build Coastguard Worker    :attr:`grad_outputs` that are of floating point or complex type and with
2137*da0073e9SAndroid Build Coastguard Worker    ``requires_grad=True``.
2138*da0073e9SAndroid Build Coastguard Worker
2139*da0073e9SAndroid Build Coastguard Worker    This function checks that backpropagating through the gradients computed
2140*da0073e9SAndroid Build Coastguard Worker    to the given :attr:`grad_outputs` are correct.
2141*da0073e9SAndroid Build Coastguard Worker
2142*da0073e9SAndroid Build Coastguard Worker    The check between numerical and analytical gradients uses :func:`~torch.allclose`.
2143*da0073e9SAndroid Build Coastguard Worker
2144*da0073e9SAndroid Build Coastguard Worker    .. note::
2145*da0073e9SAndroid Build Coastguard Worker        The default values are designed for :attr:`input` and
2146*da0073e9SAndroid Build Coastguard Worker        :attr:`grad_outputs` of double precision. This check will likely fail if
2147*da0073e9SAndroid Build Coastguard Worker        they are of less precision, e.g., ``FloatTensor``.
2148*da0073e9SAndroid Build Coastguard Worker
2149*da0073e9SAndroid Build Coastguard Worker    .. warning::
2150*da0073e9SAndroid Build Coastguard Worker       If any checked tensor in :attr:`input` and :attr:`grad_outputs` has
2151*da0073e9SAndroid Build Coastguard Worker       overlapping memory, i.e., different indices pointing to the same memory
2152*da0073e9SAndroid Build Coastguard Worker       address (e.g., from :func:`torch.expand`), this check will likely fail
2153*da0073e9SAndroid Build Coastguard Worker       because the numerical gradients computed by point perturbation at such
2154*da0073e9SAndroid Build Coastguard Worker       indices will change values at all other indices that share the same
2155*da0073e9SAndroid Build Coastguard Worker       memory address.
2156*da0073e9SAndroid Build Coastguard Worker
2157*da0073e9SAndroid Build Coastguard Worker    Args:
2158*da0073e9SAndroid Build Coastguard Worker        func (function): a Python function that takes Tensor inputs and returns
2159*da0073e9SAndroid Build Coastguard Worker            a Tensor or a tuple of Tensors
2160*da0073e9SAndroid Build Coastguard Worker        inputs (tuple of Tensor or Tensor): inputs to the function
2161*da0073e9SAndroid Build Coastguard Worker        grad_outputs (tuple of Tensor or Tensor, optional): The gradients with
2162*da0073e9SAndroid Build Coastguard Worker            respect to the function's outputs.
2163*da0073e9SAndroid Build Coastguard Worker        eps (float, optional): perturbation for finite differences
2164*da0073e9SAndroid Build Coastguard Worker        atol (float, optional): absolute tolerance
2165*da0073e9SAndroid Build Coastguard Worker        rtol (float, optional): relative tolerance
2166*da0073e9SAndroid Build Coastguard Worker        gen_non_contig_grad_outputs (bool, optional): if :attr:`grad_outputs` is
2167*da0073e9SAndroid Build Coastguard Worker            ``None`` and :attr:`gen_non_contig_grad_outputs` is ``True``, the
2168*da0073e9SAndroid Build Coastguard Worker            randomly generated gradient outputs are made to be noncontiguous
2169*da0073e9SAndroid Build Coastguard Worker        raise_exception (bool, optional): indicating whether to raise an exception if
2170*da0073e9SAndroid Build Coastguard Worker            the check fails. The exception gives more information about the
2171*da0073e9SAndroid Build Coastguard Worker            exact nature of the failure. This is helpful when debugging gradchecks.
2172*da0073e9SAndroid Build Coastguard Worker        nondet_tol (float, optional): tolerance for non-determinism. When running
2173*da0073e9SAndroid Build Coastguard Worker            identical inputs through the differentiation, the results must either match
2174*da0073e9SAndroid Build Coastguard Worker            exactly (default, 0.0) or be within this tolerance. Note that a small amount
2175*da0073e9SAndroid Build Coastguard Worker            of nondeterminism in the gradient will lead to larger inaccuracies in
2176*da0073e9SAndroid Build Coastguard Worker            the second derivative.
2177*da0073e9SAndroid Build Coastguard Worker        check_undefined_grad (bool, optional): if True, check if undefined output grads
2178*da0073e9SAndroid Build Coastguard Worker            are supported and treated as zeros
2179*da0073e9SAndroid Build Coastguard Worker        check_batched_grad (bool, optional): if True, check if we can compute
2180*da0073e9SAndroid Build Coastguard Worker            batched gradients using prototype vmap support. Defaults to False.
2181*da0073e9SAndroid Build Coastguard Worker        fast_mode (bool, optional): if True, run a faster implementation of gradgradcheck that
2182*da0073e9SAndroid Build Coastguard Worker            no longer computes the entire jacobian.
2183*da0073e9SAndroid Build Coastguard Worker        masked (bool, optional): if True, the gradients of unspecified elements of
2184*da0073e9SAndroid Build Coastguard Worker            sparse tensors are ignored (default, False).
2185*da0073e9SAndroid Build Coastguard Worker    Returns:
2186*da0073e9SAndroid Build Coastguard Worker        True if all differences satisfy allclose condition
2187*da0073e9SAndroid Build Coastguard Worker    """
2188*da0073e9SAndroid Build Coastguard Worker    assert (
2189*da0073e9SAndroid Build Coastguard Worker        check_fwd_over_rev or check_rev_over_rev
2190*da0073e9SAndroid Build Coastguard Worker    ), "Expected at least one of check_fwd_over_rev or check_rev_over_rev to be True"
2191*da0073e9SAndroid Build Coastguard Worker    assert not (
2192*da0073e9SAndroid Build Coastguard Worker        check_undefined_grad and not check_rev_over_rev
2193*da0073e9SAndroid Build Coastguard Worker    ), "Setting check_undefined_grad=True requires check_rev_over_rev to be True"
2194*da0073e9SAndroid Build Coastguard Worker    assert not (
2195*da0073e9SAndroid Build Coastguard Worker        check_batched_grad and not check_rev_over_rev
2196*da0073e9SAndroid Build Coastguard Worker    ), "Setting check_batched_grad=True requires check_rev_over_rev to be True"
2197*da0073e9SAndroid Build Coastguard Worker    # TODO: do we want to test this too?
2198*da0073e9SAndroid Build Coastguard Worker    # assert not (check_batched_forward_grad and not check_fwd_over_rev), (
2199*da0073e9SAndroid Build Coastguard Worker    #     "Setting check_batched_forward_grad=True requires check_fwd_over_rev to be True")
2200*da0073e9SAndroid Build Coastguard Worker    tupled_inputs = _as_tuple(inputs)
2201*da0073e9SAndroid Build Coastguard Worker
2202*da0073e9SAndroid Build Coastguard Worker    if grad_outputs is None:
2203*da0073e9SAndroid Build Coastguard Worker        # If grad_outputs is not specified, create random Tensors of the same shape, type, and device as the outputs
2204*da0073e9SAndroid Build Coastguard Worker
2205*da0073e9SAndroid Build Coastguard Worker        outputs = _differentiable_outputs(func(*tupled_inputs))
2206*da0073e9SAndroid Build Coastguard Worker        tupled_grad_outputs = tuple(
2207*da0073e9SAndroid Build Coastguard Worker            torch.testing.make_tensor(
2208*da0073e9SAndroid Build Coastguard Worker                x.shape,
2209*da0073e9SAndroid Build Coastguard Worker                dtype=x.dtype
2210*da0073e9SAndroid Build Coastguard Worker                if x.is_floating_point() or x.is_complex()
2211*da0073e9SAndroid Build Coastguard Worker                else torch.double,
2212*da0073e9SAndroid Build Coastguard Worker                device=x.device,
2213*da0073e9SAndroid Build Coastguard Worker                low=-1,
2214*da0073e9SAndroid Build Coastguard Worker                high=1,
2215*da0073e9SAndroid Build Coastguard Worker                requires_grad=True,
2216*da0073e9SAndroid Build Coastguard Worker                noncontiguous=gen_non_contig_grad_outputs,
2217*da0073e9SAndroid Build Coastguard Worker            )
2218*da0073e9SAndroid Build Coastguard Worker            for x in outputs
2219*da0073e9SAndroid Build Coastguard Worker        )
2220*da0073e9SAndroid Build Coastguard Worker    else:
2221*da0073e9SAndroid Build Coastguard Worker        tupled_grad_outputs = _as_tuple(grad_outputs)
2222*da0073e9SAndroid Build Coastguard Worker
2223*da0073e9SAndroid Build Coastguard Worker    num_outputs = len(tupled_grad_outputs)
2224*da0073e9SAndroid Build Coastguard Worker
2225*da0073e9SAndroid Build Coastguard Worker    # NB: We need to save the requires_grad information about the inputs here because gradcheck detaches inputs
2226*da0073e9SAndroid Build Coastguard Worker    #     before running forward mode AD
2227*da0073e9SAndroid Build Coastguard Worker    diff_input_args_indices = {
2228*da0073e9SAndroid Build Coastguard Worker        i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad
2229*da0073e9SAndroid Build Coastguard Worker    }
2230*da0073e9SAndroid Build Coastguard Worker    diff_grad_output_indices = {
2231*da0073e9SAndroid Build Coastguard Worker        i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad
2232*da0073e9SAndroid Build Coastguard Worker    }
2233*da0073e9SAndroid Build Coastguard Worker
2234*da0073e9SAndroid Build Coastguard Worker    def new_func(*args):
2235*da0073e9SAndroid Build Coastguard Worker        # Restore the requires_grad information
2236*da0073e9SAndroid Build Coastguard Worker        input_args = tuple(
2237*da0073e9SAndroid Build Coastguard Worker            x.requires_grad_() if i in diff_input_args_indices else x
2238*da0073e9SAndroid Build Coastguard Worker            for i, x in enumerate(args[:-num_outputs])
2239*da0073e9SAndroid Build Coastguard Worker        )
2240*da0073e9SAndroid Build Coastguard Worker        outputs = _differentiable_outputs(func(*input_args))
2241*da0073e9SAndroid Build Coastguard Worker        grad_outputs = tuple(
2242*da0073e9SAndroid Build Coastguard Worker            x.requires_grad_() if i in diff_grad_output_indices else x
2243*da0073e9SAndroid Build Coastguard Worker            for i, x in enumerate(args[-num_outputs:])
2244*da0073e9SAndroid Build Coastguard Worker        )
2245*da0073e9SAndroid Build Coastguard Worker        diff_input_args = tuple(
2246*da0073e9SAndroid Build Coastguard Worker            x for i, x in enumerate(input_args) if i in diff_input_args_indices
2247*da0073e9SAndroid Build Coastguard Worker        )
2248*da0073e9SAndroid Build Coastguard Worker        grad_inputs = torch.autograd.grad(
2249*da0073e9SAndroid Build Coastguard Worker            outputs, diff_input_args, grad_outputs, create_graph=True, allow_unused=True
2250*da0073e9SAndroid Build Coastguard Worker        )
2251*da0073e9SAndroid Build Coastguard Worker        grad_inputs = tuple(g for g in grad_inputs if g is not None)
2252*da0073e9SAndroid Build Coastguard Worker        return grad_inputs
2253*da0073e9SAndroid Build Coastguard Worker
2254*da0073e9SAndroid Build Coastguard Worker    return gradcheck(
2255*da0073e9SAndroid Build Coastguard Worker        new_func,
2256*da0073e9SAndroid Build Coastguard Worker        tupled_inputs + tupled_grad_outputs,
2257*da0073e9SAndroid Build Coastguard Worker        eps=eps,
2258*da0073e9SAndroid Build Coastguard Worker        atol=atol,
2259*da0073e9SAndroid Build Coastguard Worker        rtol=rtol,
2260*da0073e9SAndroid Build Coastguard Worker        raise_exception=raise_exception,
2261*da0073e9SAndroid Build Coastguard Worker        nondet_tol=nondet_tol,
2262*da0073e9SAndroid Build Coastguard Worker        check_undefined_grad=check_undefined_grad,
2263*da0073e9SAndroid Build Coastguard Worker        check_grad_dtypes=check_grad_dtypes,
2264*da0073e9SAndroid Build Coastguard Worker        check_batched_grad=check_batched_grad,
2265*da0073e9SAndroid Build Coastguard Worker        fast_mode=fast_mode,
2266*da0073e9SAndroid Build Coastguard Worker        check_forward_ad=check_fwd_over_rev,
2267*da0073e9SAndroid Build Coastguard Worker        check_backward_ad=check_rev_over_rev,
2268*da0073e9SAndroid Build Coastguard Worker        masked=masked,
2269*da0073e9SAndroid Build Coastguard Worker    )
2270