# mypy: allow-untyped-defs
import collections
import functools
import warnings
from itertools import product
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing_extensions import deprecated

import torch
import torch.testing
from torch._vmap_internals import _vmap, vmap
from torch.overrides import is_tensor_like
from torch.types import _TensorOrTensors


# Note: `get_*_jacobian` functions are added here even though we didn't intend to make them public
# since they have been exposed from before we added `__all__` and we already maintain BC for them
# We should eventually deprecate them and remove them from `__all__`
__all__ = [
    "gradcheck",
    "gradgradcheck",
    "GradcheckError",
    "get_numerical_jacobian",
    "get_analytical_jacobian",
    "get_numerical_jacobian_wrt_specific_input",
]


class GradcheckError(RuntimeError):
    r"""Error raised by :func:`gradcheck` and :func:`gradgradcheck`."""


def _is_sparse_compressed_tensor(obj: torch.Tensor):
    return obj.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }


def _is_sparse_any_tensor(obj: torch.Tensor):
    return _is_sparse_compressed_tensor(obj) or obj.layout is torch.sparse_coo


def _is_float_or_complex_tensor(obj):
    return is_tensor_like(obj) and (obj.is_floating_point() or obj.is_complex())


def _allocate_jacobians_with_inputs(
    input_tensors: Tuple, numel_output
) -> Tuple[torch.Tensor, ...]:
    # Makes zero-filled tensors from inputs. If `numel_output` is not None, for
    # each tensor in `input_tensors`, returns a new zero-filled tensor with height
    # of `t.numel` and width of `numel_output`. Otherwise, for each tensor, returns
    # a 1-d tensor with size `(t.numel,)`. Each new tensor will be strided and have
    # the same dtype and device as those of the corresponding input.
    out: List[torch.Tensor] = []
    for t in input_tensors:
        if _is_float_or_complex_tensor(t) and t.requires_grad:
            out.append(t.new_zeros((t.numel(), numel_output), layout=torch.strided))
    return tuple(out)


def _allocate_jacobians_with_outputs(
    output_tensors: Tuple, numel_input, dtype=None, device=None
) -> Tuple[torch.Tensor, ...]:
    # Makes zero-filled tensors from outputs. If `numel_input` is not None, for each
    # tensor in `output_tensors`, returns a new zero-filled tensor with height of
    # `numel_input` and width of `t.numel`. Otherwise, for each tensor, returns a
    # 1-d tensor with size (t.numel,).
    out: List[torch.Tensor] = []
    options = {"dtype": dtype, "device": device, "layout": torch.strided}
    for t in output_tensors:
        if _is_float_or_complex_tensor(t):
            out.append(t.new_zeros((numel_input, t.numel()), **options))
    return tuple(out)


def _iter_tensors(
    x: Union[torch.Tensor, Iterable[torch.Tensor]], only_requiring_grad: bool = False
) -> Iterable[torch.Tensor]:
    if is_tensor_like(x):
        # mypy doesn't narrow type of `x` to torch.Tensor
        if x.requires_grad or not only_requiring_grad:  # type: ignore[union-attr]
            yield x  # type: ignore[misc]
    elif isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
        for elem in x:
            yield from _iter_tensors(elem, only_requiring_grad)


def _densify(x):
    # return a copy of sparse x with all unspecified elements
    # "replaced" with zero-valued elements
    if isinstance(x, (list, tuple)):
        return type(x)(map(_densify, x))
    elif not is_tensor_like(x) or x.layout in {torch.strided, torch._mkldnn}:  # type: ignore[attr-defined] # no attr _mkldnn
        return x
    elif x.layout is torch.sparse_coo:
        device = x.device
        indices_dtype = x._indices().dtype
        tmp = torch.ones(x.shape[: x.sparse_dim()], dtype=torch.int8, device=device)
        indices = tmp.nonzero().t().to(dtype=indices_dtype)
        values = torch.zeros(
            (tmp.numel(), *x.shape[x.sparse_dim() :]), dtype=x.dtype, device=device
        )
        x_coalesced = x.detach().coalesce()
        if x_coalesced.numel() > 0:
            stride = tmp.stride()
            flat_indices = (
                x_coalesced.indices()
                .mul(
                    torch.tensor(stride, dtype=indices_dtype, device=device).unsqueeze(
                        1
                    )
                )
                .sum(0)
            )
            values[flat_indices] = x_coalesced.values()
        return (
            torch.sparse_coo_tensor(indices, values, x.shape)
            ._coalesced_(True)
            .requires_grad_(x.requires_grad)
        )
    elif _is_sparse_compressed_tensor(x):
        blocksize = (
            x.values().shape[1:3]
            if x.layout in {torch.sparse_bsr, torch.sparse_bsc}
            else None
        )
        compressed_indices = (
            x.crow_indices()
            if x.layout in {torch.sparse_csr, torch.sparse_bsr}
            else x.ccol_indices()
        )
        # We'll use intermediate sparse COO for simplicity
        r = _densify(x.detach().to_sparse(layout=torch.sparse_coo)).to_sparse(
            layout=x.layout, blocksize=blocksize
        )
        # Check that all elements are specified also after `to_sparse` op:
        dense_numel = r.values().numel() // max(1, r.values().shape[0])
        batch_numel = compressed_indices.numel() // compressed_indices.shape[-1]
        sparse_numel = r.numel() // max(1, dense_numel * batch_numel)
        if sparse_numel != r._nnz():
            raise AssertionError(
                f"{x.layout} densify failed: expected nnz={sparse_numel} but got {r._nnz()}"
            )
        return r.requires_grad_(x.requires_grad)
    elif _is_sparse_any_tensor(x):
        raise NotImplementedError(x.layout)
    return x


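# Illustrative sketch, not called anywhere in gradcheck (the helper name below is
# hypothetical): shows what `_densify` does to a small sparse COO tensor. The
# unspecified entries become explicitly stored zeros, so every element can later be
# perturbed under non-masked semantics.
def _example_densify_coo():
    i = torch.tensor([[0, 1], [0, 1]])
    v = torch.tensor([1.0, 2.0], dtype=torch.double)
    x = torch.sparse_coo_tensor(i, v, (2, 2)).coalesce()
    d = _densify(x)
    assert x._nnz() == 2 and d._nnz() == 4  # zeros are now materialized
    assert torch.equal(d.to_dense(), x.to_dense())  # values are unchanged
    return d

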
def _iter_tensor(x_tensor):
    # (Only used for slow gradcheck) Returns a generator that yields the following
    # elements at each iteration:
    #  1) a tensor: the same tensor is returned across all iterations. The tensor
    #     is not the same as the original x_tensor given as input - it is
    #     prepared so that it can be modified in-place. Depending on whether the
    #     input tensor is strided, sparse, or MKLDNN, the returned tensor may or may
    #     not share storage with x_tensor.
    #  2) a tuple of indices that can be used with advanced indexing (yielded in
    #     dictionary order)
    #  3) flattened index that will be used to index into the Jacobian tensor
    #
    # For a tensor t with size (2, 2), _iter_tensor yields:
    #     `x, (0, 0), 0`, `x, (0, 1), 1`, `x, (1, 0), 2`, `x, (1, 1), 3`
    #
    # where x is the t.data of the original tensor. Perturbing the entry of x
    # at index (1, 1) yields the column at flattened index 3 of the overall
    # Jacobian matrix.
    if _is_sparse_any_tensor(x_tensor):

        def get_stride(size):
            dim = len(size)
            tmp = 1
            stride = [0] * dim
            for i in reversed(range(dim)):
                stride[i] = tmp
                tmp *= size[i]
            return stride

        x_nnz = x_tensor._nnz()
        x_size = list(x_tensor.size())
        if x_tensor.layout is torch.sparse_coo:
            x_indices = x_tensor._indices().t()
            x_values = x_tensor._values()
        elif x_tensor.layout is torch.sparse_csr:
            x_indices = torch._convert_indices_from_csr_to_coo(
                x_tensor.crow_indices(), x_tensor.col_indices()
            ).t()
            x_values = x_tensor.values()
        elif x_tensor.layout is torch.sparse_csc:
            x_indices = torch._convert_indices_from_csr_to_coo(
                x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True
            ).t()
            x_values = x_tensor.values()
        elif x_tensor.layout is torch.sparse_bsr:
            x_block_values = x_tensor.values()
            x_blocksize = x_block_values.size()[1:3]
            x_indices = (
                torch._convert_indices_from_csr_to_coo(
                    x_tensor.crow_indices(), x_tensor.col_indices()
                )
                .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1)
                .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1))
                .add_(
                    torch.stack(
                        torch.where(torch.ones(x_blocksize, device=x_tensor.device))
                    ).repeat(1, x_nnz)
                )
                .t()
            )
            x_values = x_block_values.flatten(0, 2)
            x_nnz = x_values.size(0)
        elif x_tensor.layout is torch.sparse_bsc:
            x_block_values = x_tensor.values()
            x_blocksize = x_block_values.size()[1:3]
            x_indices = (
                torch._convert_indices_from_csr_to_coo(
                    x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True
                )
                .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1)
                .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1))
                .add_(
                    torch.stack(
                        torch.where(torch.ones(x_blocksize, device=x_tensor.device))
                    ).repeat(1, x_nnz)
                )
                .t()
            )
            x_values = x_block_values.flatten(0, 2)
            x_nnz = x_values.size(0)
        else:
            raise NotImplementedError(f"_iter_tensor for {x_tensor.layout} input")
        x_stride = get_stride(x_size)
        # Use .data here to get around the version check
        x_values = x_values.data
        for i in range(x_nnz):
            x_value = x_values[i]
            for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
                indices = x_indices[i].tolist() + list(x_idx)
                d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
                yield x_value, x_idx, d_idx
    elif x_tensor.layout == torch._mkldnn:  # type: ignore[attr-defined]
        for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
            # this is really inefficient, but without indexing implemented, there's
            # not really a better way than converting back and forth
            x_tensor_dense = x_tensor.to_dense()
            yield x_tensor_dense, x_idx, d_idx
    else:
        # Use .data here to get around the version check
        x_tensor = x_tensor.data
        for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
            yield x_tensor, x_idx, d_idx


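# Illustrative sketch (hypothetical helper, not part of the gradcheck flow): walks a
# small strided tensor with `_iter_tensor` and records the advanced-indexing tuple and
# flattened Jacobian column index yielded at each step, matching the (2, 2) example in
# the comment above.
def _example_iter_tensor_protocol():
    t = torch.arange(4.0).reshape(2, 2)
    visited = [(idx, d_idx) for _, idx, d_idx in _iter_tensor(t)]
    assert visited == [((0, 0), 0), ((0, 1), 1), ((1, 0), 2), ((1, 1), 3)]
    return visited

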
def _get_numerical_jacobian(
    fn, inputs, outputs=None, target=None, eps=1e-3, is_forward_ad=False
) -> List[Tuple[torch.Tensor, ...]]:
    """Compute the numerical Jacobian of `fn(inputs)` with respect to `target`.

    If `target` is not specified, it defaults to `inputs`. Returns M * N Jacobians
    where N is the number of tensors in `target` that require grad and M is the
    number of non-integral outputs.

    Args:
        fn: the function to compute the jacobian for
        inputs: inputs to `fn`
        outputs: provide precomputed outputs to avoid one extra invocation of fn
        target: the Tensors wrt which Jacobians are calculated (default=`inputs`)
        eps: the magnitude of the perturbation during finite differencing
             (default=`1e-3`)
        is_forward_ad: if this numerical jacobian is computed to be checked wrt
                       forward AD gradients (this is used for error checking only)

    Returns:
        A list of N M-tuples of tensors, one tuple per differentiable input and one
        tensor per non-integral output.

    Note that `target` may not even be part of `input` to `fn`, so please be
    **very careful** not to clone `target` here.
    """
    jacobians: List[Tuple[torch.Tensor, ...]] = []
    if outputs is None:
        outputs = _as_tuple(fn(*_as_tuple(inputs)))
    if not is_forward_ad and any(o.is_complex() for o in outputs):
        raise ValueError(
            "Expected output to be non-complex. get_numerical_jacobian no "
            "longer supports functions that return complex outputs."
        )
    if target is None:
        target = inputs
    inp_indices = [
        i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad
    ]
    for i, (inp, inp_idx) in enumerate(zip(_iter_tensors(target, True), inp_indices)):
        jacobians += [
            get_numerical_jacobian_wrt_specific_input(
                fn,
                inp_idx,
                inputs,
                outputs,
                eps,
                input=inp,
                is_forward_ad=is_forward_ad,
            )
        ]
    return jacobians


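# Illustrative sketch (hypothetical helper): uses `_get_numerical_jacobian` on a tiny
# elementwise function and checks the result against the known analytical answer. Each
# returned entry has shape (input.numel(), output.numel()); for x -> x**2 that is
# diag(2 * x).
def _example_numerical_jacobian():
    x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.double, requires_grad=True)
    (jac_wrt_x,) = _get_numerical_jacobian(lambda t: t**2, (x,), eps=1e-6)[0]
    assert torch.allclose(jac_wrt_x, torch.diag(2 * x.detach()), atol=1e-4)
    return jac_wrt_x

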
@deprecated(
    "`get_numerical_jacobian` was part of PyTorch's private API and not "
    "meant to be exposed. We are deprecating it and it will be removed "
    "in a future version of PyTorch. If you have a specific use for "
    "this or a feature request for this to be a stable API, please file "
    "us an issue at https://github.com/pytorch/pytorch/issues/new",
    category=FutureWarning,
)
def get_numerical_jacobian(fn, inputs, target=None, eps=1e-3, grad_out=1.0):
    """Compute the numerical Jacobian for a given fn and its inputs.

    This is a Deprecated API.

    Args:
        fn: the function to compute the Jacobian for (must take inputs as a tuple)
        inputs: inputs to `fn`
        target: the Tensors wrt which Jacobians are calculated (default=`inputs`)
        eps: the magnitude of the perturbation during finite differencing
             (default=`1e-3`)

    Returns:
        A list of Jacobians of `fn` (restricted to its first output) with respect to
        each input or target, if provided.

    Note that `target` may not even be part of `input` to `fn`, so please be
    **very careful** not to clone `target` here.
    """
    if (
        grad_out != 1.0
    ):  # grad_out param is only kept for backward compatibility reasons
        raise ValueError(
            "Expected grad_out to be 1.0. get_numerical_jacobian no longer "
            "supports values of grad_out != 1.0."
        )

    def fn_pack_inps(*inps):
        return fn(inps)

    jacobians = _get_numerical_jacobian(fn_pack_inps, inputs, None, target, eps)

    return tuple(jacobian_for_each_output[0] for jacobian_for_each_output in jacobians)


def _compute_numerical_gradient(fn, entry, v, norm_v, nbhd_checks_fn):
    # Computes the numerical directional derivative as a central finite difference
    # of function `fn` at input `entry`, perturbed by vector `v`.
    if _is_sparse_compressed_tensor(entry):
        # sparse compressed tensors don't implement sub/add/copy_
        # yet. However, under non-masked semantics entry and v
        # have the same sparse indices ...
        assert entry.layout == v.layout, (entry.layout, v.layout)
        assert entry._nnz() == v._nnz(), (entry._nnz(), v._nnz(), entry.shape)
        # ... the finite differencing can be performed on values only:
        entry = entry.values()
        v = v.values()
        # we'll detach to avoid backward computations that sparse
        # tensors have limited support for.
        entry = entry.detach()

    orig = entry.clone()
    entry.copy_(orig - v)
    outa = fn()
    entry.copy_(orig + v)
    outb = fn()
    entry.copy_(orig)

    def compute(a, b):
        nbhd_checks_fn(a, b)
        ret = (b - a) / (2 * norm_v)  # use central difference approx
        return ret.detach().reshape(-1)

    return tuple(compute(a, b) for (a, b) in zip(outa, outb))


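# Illustrative sketch (hypothetical helper): the central-difference rule used by
# `_compute_numerical_gradient`, written out for a scalar function. The O(eps**2)
# truncation error is the reason gradcheck perturbs in both directions rather than
# using a one-sided difference.
def _example_central_difference():
    x = torch.tensor(0.5, dtype=torch.double)
    eps = 1e-6
    estimate = (torch.sin(x + eps) - torch.sin(x - eps)) / (2 * eps)
    assert torch.allclose(estimate, torch.cos(x), atol=1e-9)
    return estimate

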
def _compute_numerical_jvps_wrt_specific_input(
    jvp_fn, delta, input_is_complex, is_forward_ad=False
) -> List[torch.Tensor]:
    # Computing the jacobian only works for real delta
    # For details on the algorithm used here, refer to
    # Section 3.5.3 of https://arxiv.org/pdf/1701.00392.pdf
    # s = fn(z) where z = x for real valued input
    # and z = x + yj for complex valued input
    jvps: List[torch.Tensor] = []
    ds_dx_tup = jvp_fn(delta[0] if isinstance(delta, tuple) else delta)

    if input_is_complex:  # C -> R
        ds_dy_tup = (
            jvp_fn(delta[1] * 1j) if isinstance(delta, tuple) else jvp_fn(delta * 1j)
        )
        for ds_dx, ds_dy in zip(ds_dx_tup, ds_dy_tup):
            assert not ds_dx.is_complex()
            # conjugate Wirtinger derivative
            conj_w_d = ds_dx + ds_dy * 1j
            jvps.append(conj_w_d)
    else:
        for ds_dx in ds_dx_tup:  # R -> R or (R -> C for the forward AD case)
            assert is_forward_ad or not ds_dx.is_complex()
            jvps.append(ds_dx)
    return jvps


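# Illustrative sketch (hypothetical helper): the two-pass perturbation used above,
# written out for the scalar map f(z) = |z|^2. Perturbing the real and imaginary parts
# separately gives df/dx and df/dy, and df/dx + 1j * df/dy equals 2 * z for this
# function, i.e. twice the Wirtinger derivative with respect to conj(z).
def _example_complex_numerical_jvp():
    z = torch.tensor(1.0 + 2.0j, dtype=torch.complex128)
    eps = 1e-6

    def f(t):
        return (t * t.conj()).real

    ds_dx = (f(z + eps) - f(z - eps)) / (2 * eps)  # perturb the real part
    ds_dy = (f(z + eps * 1j) - f(z - eps * 1j)) / (2 * eps)  # perturb the imaginary part
    combined = ds_dx + ds_dy * 1j
    assert torch.allclose(combined, 2 * z, atol=1e-6)
    return combined

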
def _combine_jacobian_cols(
    jacobians_cols: Dict[int, List[torch.Tensor]], outputs, input, numel
) -> Tuple[torch.Tensor, ...]:
    # jacobian_cols maps column_idx -> output_idx -> single column of jacobian Tensor
    # we return a tuple that maps output_idx -> full jacobian Tensor
    jacobians = _allocate_jacobians_with_outputs(
        outputs, numel, dtype=input.dtype if input.dtype.is_complex else None
    )
    for i, jacobian in enumerate(jacobians):
        for k, v in jacobians_cols.items():
            jacobian[k] = v[i]
    return jacobians


def _prepare_input(
    input: torch.Tensor, maybe_perturbed_input: Optional[torch.Tensor], fast_mode=False
) -> torch.Tensor:
    # Prepares the inputs to be passed into the function while including the new
    # modified input.
    if input.layout == torch._mkldnn:  # type: ignore[attr-defined] # no attr _mkldnn
        # Convert back to mkldnn
        if maybe_perturbed_input is not None:
            return maybe_perturbed_input.to_mkldnn()
        else:
            return input
    elif _is_sparse_any_tensor(input):
        if fast_mode and maybe_perturbed_input is not None:
            # entry is already a "cloned" version of the original tensor
            # thus changes to entry are not reflected in the input
            return maybe_perturbed_input
        else:
            return input
    else:
        # We cannot use entry (input.data) if we want gradgrad to work because
        # fn (in the gradgrad case) needs to compute grad wrt input
        return input


def _check_outputs_same_dtype_and_shape(output1, output2, eps, idx=None) -> None:
    # Check that the returned outputs don't have different dtype or shape when you
    # perturb the input
    on_index = f"on index {idx} " if idx is not None else ""
    assert output1.shape == output2.shape, (
        f"Expected `func` to return outputs with the same shape"
        f" when inputs are perturbed {on_index}by {eps}, but got:"
        f" shapes {output1.shape} and {output2.shape}."
    )
    assert output1.dtype == output2.dtype, (
        f"Expected `func` to return outputs with the same dtype"
        f" when inputs are perturbed {on_index}by {eps}, but got:"
        f" dtypes {output1.dtype} and {output2.dtype}."
    )


def get_numerical_jacobian_wrt_specific_input(
    fn, input_idx, inputs, outputs, eps, input=None, is_forward_ad=False
) -> Tuple[torch.Tensor, ...]:
    # Computes the numerical jacobians wrt a single input. Returns N jacobian
    # tensors, where N is the number of outputs. We use a dictionary for
    # jacobian_cols because indices aren't necessarily consecutive for sparse inputs.
    # When we perturb only a single element of the input tensor at a time, the jvp
    # is equivalent to a single col of the Jacobian matrix of fn.
    jacobian_cols: Dict[int, List[torch.Tensor]] = {}
    input = inputs[input_idx] if input is None else input
    assert input.requires_grad
    for x, idx, d_idx in _iter_tensor(input):
        wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, x)
        input_to_perturb = x[idx]
        nbhd_checks_fn = functools.partial(
            _check_outputs_same_dtype_and_shape, idx=idx, eps=eps
        )
        jvp_fn = _get_numerical_jvp_fn(
            wrapped_fn, input_to_perturb, eps, nbhd_checks_fn
        )
        jacobian_cols[d_idx] = _compute_numerical_jvps_wrt_specific_input(
            jvp_fn, eps, x.is_complex(), is_forward_ad
        )
    return _combine_jacobian_cols(jacobian_cols, outputs, input, input.numel())


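# Illustrative sketch (hypothetical helper): perturbing a single element of the input is
# a JVP with a one-hot direction, which selects one column of the Jacobian. For the
# linear map f(x) = A @ x the column recovered for element k is exactly A[:, k].
def _example_single_element_perturbation():
    torch.manual_seed(0)
    A = torch.randn(3, 4, dtype=torch.double)
    x = torch.randn(4, dtype=torch.double)
    eps, k = 1e-6, 2
    e_k = torch.zeros_like(x)
    e_k[k] = eps
    column = (A @ (x + e_k) - A @ (x - e_k)) / (2 * eps)
    assert torch.allclose(column, A[:, k], atol=1e-8)
    return column

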
def _get_analytical_jacobian_forward_ad(
    fn, inputs, outputs, *, check_grad_dtypes=False, all_u=None
) -> Tuple[Tuple[torch.Tensor, ...], ...]:
    """Compute the analytical Jacobian of `fn(inputs)` using forward mode AD with respect to `target`.

    Return N * M Jacobians where N is the number of tensors in target that require grad and
    M is the number of non-integral outputs.
    Contrary to other functions here, this function requires "inputs" to actually be used by the function.
    The computed value is expected to be wrong if the function captures the inputs by side effect instead of
    using the passed ones (many torch.nn tests do this).

    Args:
        fn: the function to compute the jacobian for
        inputs: inputs to `fn`
        outputs: provide precomputed outputs to avoid one extra invocation of fn
        check_grad_dtypes: if True, will check that the gradient dtypes are valid
        all_u (optional): if provided, the Jacobian will be right-multiplied by this vector

    Returns:
        A tuple of N M-tuples of tensors
    """
    # To avoid early import issues
    fwAD = torch.autograd.forward_ad

    tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad)

    if any(i.is_complex() for i in tensor_inputs):
        raise ValueError(
            "Expected inputs to be non-complex for _get_analytical_jacobian_forward_ad."
        )

    if all_u:
        jacobians = tuple(
            _allocate_jacobians_with_outputs(outputs, 1) for i in tensor_inputs
        )
    else:
        jacobians = tuple(
            _allocate_jacobians_with_outputs(outputs, i.numel()) for i in tensor_inputs
        )

    with fwAD.dual_level():
        fw_grads = []
        dual_inputs = []
        for i, inp in enumerate(inputs):
            if is_tensor_like(inp) and inp.requires_grad:
                if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
                    raise ValueError(
                        "MKLDNN inputs are not supported for forward AD gradcheck."
                    )

                inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
                # If inp is a differentiable view, the dual might not be the tangent given to
                # make_dual, so read it explicitly from the dual tensor
                fw_grads.append(fwAD.unpack_dual(inp)[1])
            dual_inputs.append(inp)

        if all_u:
            # Do the full reduction in one pass
            # To be consistent with numerical evaluation, we actually compute one reduction per input
            for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
                fw_grad.copy_(u.view_as(fw_grad))
                raw_outputs = _as_tuple(fn(*dual_inputs))
                dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs)
                for index_o, d_o in enumerate(dual_outputs):
                    val, res = fwAD.unpack_dual(d_o)
                    if (
                        check_grad_dtypes
                        and res is not None
                        and val.is_complex() != res.is_complex()
                    ):
                        raise GradcheckError("Forward AD gradient has dtype mismatch.")

                    # Remove extra dimension of size 1 corresponding to the reduced input
                    jacobians[i][index_o].squeeze_(0)
                    if res is None:
                        jacobians[i][index_o].zero_()
                    else:
                        jacobians[i][index_o].copy_(res.reshape(-1))
                fw_grad.zero_()
        else:
            # Reconstruct the full Jacobian column by column
            for i, fw_grad in enumerate(fw_grads):
                for lin_idx, grad_idx in enumerate(
                    product(*[range(m) for m in fw_grad.size()])
                ):
                    fw_grad[grad_idx] = 1.0
                    raw_outputs = _as_tuple(fn(*dual_inputs))
                    dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs)
                    for index_o, d_o in enumerate(dual_outputs):
                        val, res = fwAD.unpack_dual(d_o)
                        if (
                            check_grad_dtypes
                            and res is not None
                            and val.is_complex() != res.is_complex()
                        ):
                            raise GradcheckError(
                                "Forward AD gradient has dtype mismatch."
                            )

                        if res is None:
                            jacobians[i][index_o][lin_idx].zero_()
                        else:
                            jacobians[i][index_o][lin_idx].copy_(res.reshape(-1))
                    fw_grad[grad_idx] = 0.0

    return jacobians


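# Illustrative sketch (hypothetical helper): the column-by-column reconstruction above,
# reduced to a single probe direction with the public forward-AD API. Seeding the
# tangent with e_0 makes `unpack_dual` return the first column of the Jacobian; for
# elementwise sin that column is cos(x[0]) * e_0.
def _example_forward_ad_jacobian_column():
    import torch.autograd.forward_ad as fwAD

    x = torch.tensor([0.3, 0.7, 1.1], dtype=torch.double)
    tangent = torch.zeros_like(x)
    tangent[0] = 1.0
    with fwAD.dual_level():
        dual_x = fwAD.make_dual(x, tangent)
        _, jvp = fwAD.unpack_dual(torch.sin(dual_x))
    expected = torch.zeros_like(x)
    expected[0] = torch.cos(x[0])
    assert torch.allclose(jvp, expected)
    return jvp

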
def _get_input_to_perturb(input):
    # Prepare the input so that it can be modified in-place and do certain
    # operations that require the tensor to have strides. If fast_mode=False,
    # _iter_tensor would handle the below cases:
    if input.layout == torch._mkldnn:  # type: ignore[attr-defined] # no attr _mkldnn
        # Convert to dense so we can perform operations that require strided tensors
        input_to_perturb = input.to_dense()
    elif _is_sparse_any_tensor(input):
        # Clone because input may require grad, and copy_ calls resize_,
        # which is not allowed for .data
        input_to_perturb = input.clone()
    else:
        input_to_perturb = input.data
    return input_to_perturb


def _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, fast_mode=False):
    # Wraps `fn` so that its inputs are already supplied
    def wrapped_fn():
        inp = tuple(
            _prepare_input(a, input_to_perturb if i == input_idx else None, fast_mode)
            if is_tensor_like(a)
            else a
            for i, a in enumerate(_as_tuple(inputs))
        )
        return tuple(a.clone() for a in _as_tuple(fn(*inp)))

    return wrapped_fn


def _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn):
    # Wraps jvp_fn so that certain arguments are already supplied
    def jvp_fn(delta):
        return _compute_numerical_gradient(
            wrapped_fn, input_to_perturb, delta, eps, nbhd_checks_fn
        )

    return jvp_fn


def _reshape_tensor_or_tuple(u, shape):
    # We don't need to reshape when input corresponding to u is sparse
    if isinstance(u, tuple):
        if not _is_sparse_any_tensor(u[0]):
            return (u[0].reshape(shape), u[1].reshape(shape))
    else:
        if not _is_sparse_any_tensor(u):
            return u.reshape(shape)
    return u


def _mul_tensor_or_tuple(u, k):
    if isinstance(u, tuple):
        return (k * u[0], k * u[1])
    else:
        return k * u


def _get_numerical_jvp_wrt_specific_input(
    fn, input_idx, inputs, u, eps, is_forward_ad=False
) -> List[torch.Tensor]:
    input = inputs[input_idx]
    input_to_perturb = _get_input_to_perturb(input)
    wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, True)
    nbhd_checks_fn = functools.partial(_check_outputs_same_dtype_and_shape, eps=eps)
    jvp_fn = _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn)
    u = _reshape_tensor_or_tuple(u, input_to_perturb.shape)
    u = _mul_tensor_or_tuple(u, eps)
    return _compute_numerical_jvps_wrt_specific_input(
        jvp_fn, u, input.is_complex(), is_forward_ad
    )


def _get_numerical_vJu(
    fn, inputs, inp_indices, func_out, all_u, all_v, eps, is_forward_ad
):
    # Note that all_v can also be None; in that case, this function only computes Ju.
    reduced_jacobians: List[List[torch.Tensor]] = []
    for i, (inp_idx, u) in enumerate(zip(inp_indices, all_u)):
        all_Ju = _get_numerical_jvp_wrt_specific_input(
            fn, inp_idx, inputs, u, eps, is_forward_ad
        )
        # Filter out the Ju for non floating point outputs
        filtered_Ju = []
        func_out = _as_tuple(func_out)
        assert len(all_Ju) == len(func_out)
        for Ju, output in zip(all_Ju, func_out):
            if _is_float_or_complex_tensor(output):
                filtered_Ju.append(Ju)
            else:
                # TODO: handle the other Ju
                pass
        if all_v is not None:
            jacobian_scalars: List[torch.Tensor] = []
            for v, Ju in zip(all_v, filtered_Ju):
                jacobian_scalars.append(_dot_with_type_promotion(v, Ju))
            reduced_jacobians.append(jacobian_scalars)
        else:
            reduced_jacobians.append(filtered_Ju)
    return reduced_jacobians


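# Illustrative sketch (hypothetical helper): the scalar quantity fast gradcheck
# compares. The directional derivative J @ u is estimated with a central difference and
# contracted with v; the same scalar is obtained analytically as (v^T J) @ u via a VJP
# from torch.autograd.grad.
def _example_numerical_vs_analytical_vJu():
    torch.manual_seed(0)
    x = torch.randn(5, dtype=torch.double, requires_grad=True)
    u = torch.randn(5, dtype=torch.double)
    v = torch.randn(5, dtype=torch.double)
    eps = 1e-6

    def f(t):
        return t.sin() * t.sum()

    with torch.no_grad():
        Ju = (f(x + eps * u) - f(x - eps * u)) / (2 * eps)  # numerical J @ u
    numerical = v.dot(Ju)
    (vJ,) = torch.autograd.grad(f(x), x, grad_outputs=v)  # analytical v^T J
    analytical = vJ.dot(u)
    assert torch.allclose(numerical, analytical, atol=1e-7)
    return numerical, analytical

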
def _check_jacobians_equal(j1, j2, atol):
    # Check whether the max difference between two Jacobian tensors is within some
    # tolerance `atol`.
    for j1_x, j2_x in zip(j1, j2):
        if j1_x.numel() != 0 and (j1_x - j2_x).abs().max() > atol:
            return False
    return True


def _stack_and_check_tensors(
    list_of_list_of_tensors, inputs, numel_outputs
) -> Tuple[Tuple[torch.Tensor, ...], bool, bool]:
    # For the ith list of tensors, checks that each tensor has the same size and
    # dtype as the ith differentiable input.
    out_jacobians = _allocate_jacobians_with_inputs(inputs, numel_outputs)
    diff_input_list = list(_iter_tensors(inputs, True))
    correct_grad_sizes = True
    correct_grad_types = True
    for i, tensor_list in enumerate(list_of_list_of_tensors):
        inp = diff_input_list[i]
        out_jacobian = out_jacobians[i]
        for j, tensor in enumerate(tensor_list):
            if tensor is not None and tensor.size() != inp.size():
                correct_grad_sizes = False
            elif tensor is not None and tensor.dtype != inp.dtype:
                correct_grad_types = False
            if tensor is None:
                out_jacobian[:, j].zero_()
            else:
                dense = (
                    tensor.to_dense() if not tensor.layout == torch.strided else tensor
                )
                assert out_jacobian[:, j].numel() == dense.numel()
                out_jacobian[:, j] = dense.reshape(-1)
    return out_jacobians, correct_grad_sizes, correct_grad_types


FAILED_NONDET_MSG = """\n
NOTE: If your op relies on non-deterministic operations, i.e., it is listed here:
https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
this failure might be expected.

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `nondet_tol=<tol>` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
  to have `gradcheck_nondet_tol=<tol>`.
- is a Module test (e.g., in common_nn.py), then modify the corresponding
  module_test entry to have `gradcheck_nondet_tol=<tol>`
"""



FAILED_NONDET_MSG = """\n
NOTE: If your op relies on non-deterministic operations, i.e., it is listed here:
https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
this failure might be expected.

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `nondet_tol=<tol>` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
  to have `gradcheck_nondet_tol=<tol>`.
- is a Module test (e.g., in common_nn.py), then modify the corresponding
  module_test entry to have `gradcheck_nondet_tol=<tol>`
"""


def _check_analytical_jacobian_attributes(
    inputs, output, nondet_tol, check_grad_dtypes, fast_mode=False, v=None
) -> Tuple[torch.Tensor, ...]:
    # This is used by both fast and slow mode:
    # - For slow mode, vjps[i][j] is the jth row of the Jacobian wrt the ith
    #   input.
    # - For fast mode, vjps[i][0] is a linear combination of the rows
    #   of the Jacobian wrt the ith input
    diff_input_list = list(_iter_tensors(inputs, True))

    def vjp_fn(grad_output):
        return torch.autograd.grad(
            output, diff_input_list, grad_output, retain_graph=True, allow_unused=True
        )

    # Compute everything twice to check for nondeterminism (which we call reentrancy)
    if fast_mode:
        vjps1 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v)
        vjps2 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v)
    else:
        vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
        vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())

    output_numel = output.numel() if not fast_mode else 1
    jacobians1, types_ok, sizes_ok = _stack_and_check_tensors(
        vjps1, inputs, output_numel
    )
    jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel)
    reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol)

    if not types_ok and check_grad_dtypes:
        raise GradcheckError("Gradient has dtype mismatch")
    if not sizes_ok:
        raise GradcheckError("Analytical gradient has incorrect size")
    if not reentrant:
        raise GradcheckError(
            "Backward is not reentrant, i.e., running backward with "
Worker "same input and grad_output multiple times gives different values, " 793*da0073e9SAndroid Build Coastguard Worker "although analytical gradient matches numerical gradient." 794*da0073e9SAndroid Build Coastguard Worker f"The tolerance for nondeterminism was {nondet_tol}." + FAILED_NONDET_MSG 795*da0073e9SAndroid Build Coastguard Worker ) 796*da0073e9SAndroid Build Coastguard Worker return jacobians1 797*da0073e9SAndroid Build Coastguard Worker 798*da0073e9SAndroid Build Coastguard Worker 799*da0073e9SAndroid Build Coastguard Workerdef _get_analytical_vJu_backward_mode( 800*da0073e9SAndroid Build Coastguard Worker inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u 801*da0073e9SAndroid Build Coastguard Worker): 802*da0073e9SAndroid Build Coastguard Worker reduced_jacobians: List[List[torch.Tensor]] = [] 803*da0073e9SAndroid Build Coastguard Worker for output, v in zip(outputs, all_v): 804*da0073e9SAndroid Build Coastguard Worker all_vJ = _check_analytical_jacobian_attributes( 805*da0073e9SAndroid Build Coastguard Worker inputs, output, nondet_tol, check_grad_dtypes, fast_mode=True, v=v 806*da0073e9SAndroid Build Coastguard Worker ) 807*da0073e9SAndroid Build Coastguard Worker jacobian_scalars: List[torch.Tensor] = [] 808*da0073e9SAndroid Build Coastguard Worker for vJ, u in zip(all_vJ, all_u): 809*da0073e9SAndroid Build Coastguard Worker # Why do we need squeeze here? vJ is a 2-d tensor so that we can reuse 810*da0073e9SAndroid Build Coastguard Worker # the error checking logic from slow mode 811*da0073e9SAndroid Build Coastguard Worker vJ = vJ.T.squeeze(0) 812*da0073e9SAndroid Build Coastguard Worker if vJ.is_complex(): # C -> R 813*da0073e9SAndroid Build Coastguard Worker tv = torch.view_as_real(vJ.resolve_conj()) 814*da0073e9SAndroid Build Coastguard Worker tr = tv.select(-1, 0) 815*da0073e9SAndroid Build Coastguard Worker ti = tv.select(-1, 1) 816*da0073e9SAndroid Build Coastguard Worker jacobian_scalars.append(tr.dot(u[0]) + 1j * ti.dot(u[1])) 817*da0073e9SAndroid Build Coastguard Worker else: # R -> R 818*da0073e9SAndroid Build Coastguard Worker jacobian_scalars.append(vJ.dot(u)) 819*da0073e9SAndroid Build Coastguard Worker reduced_jacobians.append(jacobian_scalars) 820*da0073e9SAndroid Build Coastguard Worker return reduced_jacobians 821*da0073e9SAndroid Build Coastguard Worker 822*da0073e9SAndroid Build Coastguard Worker 823*da0073e9SAndroid Build Coastguard Worker@deprecated( 824*da0073e9SAndroid Build Coastguard Worker "`get_analytical_jacobian` was part of PyTorch's private API and not " 825*da0073e9SAndroid Build Coastguard Worker "meant to be exposed. We are deprecating it and it will be removed " 826*da0073e9SAndroid Build Coastguard Worker "in a future version of PyTorch. 


def _get_analytical_vJu_backward_mode(
    inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u
):
    reduced_jacobians: List[List[torch.Tensor]] = []
    for output, v in zip(outputs, all_v):
        all_vJ = _check_analytical_jacobian_attributes(
            inputs, output, nondet_tol, check_grad_dtypes, fast_mode=True, v=v
        )
        jacobian_scalars: List[torch.Tensor] = []
        for vJ, u in zip(all_vJ, all_u):
            # Why do we need squeeze here? vJ is a 2-d tensor so that we can reuse
            # the error checking logic from slow mode
            vJ = vJ.T.squeeze(0)
            if vJ.is_complex():  # C -> R
                tv = torch.view_as_real(vJ.resolve_conj())
                tr = tv.select(-1, 0)
                ti = tv.select(-1, 1)
                jacobian_scalars.append(tr.dot(u[0]) + 1j * ti.dot(u[1]))
            else:  # R -> R
                jacobian_scalars.append(vJ.dot(u))
        reduced_jacobians.append(jacobian_scalars)
    return reduced_jacobians


@deprecated(
    "`get_analytical_jacobian` was part of PyTorch's private API and not "
    "meant to be exposed. We are deprecating it and it will be removed "
    "in a future version of PyTorch. If you have a specific use for "
    "this, or a feature request for it to become a stable API, please "
    "file an issue at https://github.com/pytorch/pytorch/issues/new",
    category=FutureWarning,
)
def get_analytical_jacobian(inputs, output, nondet_tol=0.0, grad_out=1.0):
    # Replicates the behavior of the old get_analytical_jacobian before the refactor
    # This shares much of its code with _check_analytical_jacobian_attributes
    if (
        grad_out != 1.0
    ):  # grad_out param is only kept for backward compatibility reasons
        raise ValueError(
            "Expected grad_out to be 1.0. get_analytical_jacobian no longer "
            "supports values of grad_out != 1.0."
        )
    if output.is_complex():
        raise ValueError(
            "Expected output to be non-complex. get_analytical_jacobian no "
            "longer supports functions that return complex outputs."
        )
    diff_input_list = list(_iter_tensors(inputs, True))

    def vjp_fn(grad_output):
        return torch.autograd.grad(
            output, diff_input_list, grad_output, retain_graph=True, allow_unused=True
        )

    # Compute everything twice to check for nondeterminism (which we call reentrancy)
    vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
    vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())

    output_numel = output.numel()
    jacobians1, types_ok, sizes_ok = _stack_and_check_tensors(
        vjps1, inputs, output_numel
    )
    jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel)
    reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol)

    return jacobians1, reentrant, sizes_ok, types_ok


def _get_analytical_jacobian(inputs, outputs, input_idx, output_idx):
    # Computes the analytical Jacobian in slow mode for a single input-output pair.
    # Forgoes performing checks on dtype, shape, and reentrancy.
    jacobians = _check_analytical_jacobian_attributes(
        inputs, outputs[output_idx], nondet_tol=float("inf"), check_grad_dtypes=False
    )
    return jacobians[input_idx]


def _compute_analytical_jacobian_rows(
    vjp_fn, sample_output
) -> List[List[Optional[torch.Tensor]]]:
    # Computes Jacobian row-by-row by projecting `vjp_fn` = v^T J on standard basis
    # vectors: vjp_fn(e) = e^T J is a corresponding row of the Jacobian.
    # NB: this function does not assume vjp_fn(v) to return tensors with the same
    # number of elements for different v. This is checked when we later combine the
    # rows into a single tensor.
    grad_out_base = torch.zeros_like(
        sample_output, memory_format=torch.legacy_contiguous_format
    )
    flat_grad_out = grad_out_base.view(-1)
    # jacobians_rows[i][j] is the Jacobian jth row for the ith input
    jacobians_rows: List[List[Optional[torch.Tensor]]] = []
    for j in range(flat_grad_out.numel()):
        flat_grad_out.zero_()
        flat_grad_out[j] = 1.0  # projection for jth row of Jacobian
        grad_inputs = vjp_fn(grad_out_base)
        for i, d_x in enumerate(grad_inputs):
            if j == 0:
                jacobians_rows.append([])
            jacobians_rows[i] += [
                d_x.clone() if isinstance(d_x, torch.Tensor) else None
            ]
    return jacobians_rows


def _get_analytical_vjps_wrt_specific_output(
    vjp_fn, sample_output, v
) -> List[List[Optional[torch.Tensor]]]:
    vjps: List[List[Optional[torch.Tensor]]] = []
    grad_inputs = vjp_fn(v.reshape(sample_output.shape))
    for vjp in grad_inputs:
        vjps.append([vjp.clone() if isinstance(vjp, torch.Tensor) else None])
    return vjps
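
# A minimal sketch of the row-by-row construction used above (hypothetical
# function and shapes; not part of this module): feeding the standard basis
# vector e_j as grad_output to autograd.grad yields e_j^T J, i.e., the jth row
# of the Jacobian.
#
#   import torch
#
#   x = torch.randn(3, dtype=torch.float64, requires_grad=True)
#   y = torch.stack([x.sum(), (x ** 2).sum()])  # R^3 -> R^2
#   rows = []
#   for j in range(y.numel()):
#       e_j = torch.zeros_like(y)
#       e_j[j] = 1.0
#       (row,) = torch.autograd.grad(y, x, grad_outputs=e_j, retain_graph=True)
#       rows.append(row)
#   jacobian = torch.stack(rows)  # shape (2, 3)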

def _check_inputs(tupled_inputs) -> bool:
    # Make sure that gradients are saved for at least one input
    any_input_requiring_grad = False
    for idx, inp in enumerate(tupled_inputs):
        if is_tensor_like(inp) and inp.requires_grad:
            if not (inp.dtype == torch.float64 or inp.dtype == torch.complex128):
                warnings.warn(
                    f"Input #{idx} requires gradient and "
                    "is not a double precision floating point or complex tensor. "
                    "This check will likely fail if all the inputs are "
                    "not of double precision floating point or complex dtype. "
                )
            if inp.is_sparse:
                content = inp._values()
            elif _is_sparse_compressed_tensor(inp):
                content = inp.values()
            else:
                content = inp
            # TODO: To cover more problematic cases, replace stride = 0 check with
            # "any overlap in memory" once we have a proper function to check it.
            if content.layout is not torch._mkldnn:  # type: ignore[attr-defined]
                if not all(
                    st > 0 or sz <= 1
                    for st, sz in zip(content.stride(), content.size())
                ):
                    raise RuntimeError(
                        f"The {idx}th input has a dimension with stride 0. gradcheck only "
                        "supports inputs that are non-overlapping to be able to "
                        "compute the numerical gradients correctly. You should call "
                        ".contiguous() on the input before passing it to gradcheck."
                    )
            any_input_requiring_grad = True

    if not any_input_requiring_grad:
        raise ValueError(
            "gradcheck expects at least one input tensor to require gradient, "
            "but none of them have requires_grad=True."
        )
    return True
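
# Illustrative call (hypothetical function and tensor; not part of this module):
# gradcheck expects double precision (or complex128) inputs with
# requires_grad=True so that the finite-difference estimates are accurate
# enough to compare against analytical gradients.
#
#   import torch
#
#   x = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)
#   assert torch.autograd.gradcheck(torch.tanh, (x,), eps=1e-6, atol=1e-4)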

def _check_outputs(outputs) -> None:
    if any(_is_sparse_any_tensor(t) for t in outputs if isinstance(t, torch.Tensor)):
        # it is easier to call to_dense() on the sparse output than
        # to modify analytical jacobian
        raise ValueError(
            "Sparse output is not yet supported by gradcheck. "
            "Please call to_dense(masked_grad=...) on the output of fn for gradcheck."
        )
    if any(t.layout == torch._mkldnn for t in outputs if isinstance(t, torch.Tensor)):  # type: ignore[attr-defined]
        raise ValueError(
            "MKLDNN output is not yet supported by gradcheck. "
            "Please call to_dense(masked_grad=...) on the output of fn for gradcheck."
        )


def _check_no_differentiable_outputs(
    func, inputs, func_out, eps, *, is_forward_ad
) -> bool:
    # When there are no differentiable outputs, the numerical gradient for a function is
    # expected to be zero.
    jacobians_all_inputs_outputs = _get_numerical_jacobian(
        func, inputs, func_out, eps=eps, is_forward_ad=is_forward_ad
    )
    for jacobians_all_outputs_and_fixed_input in jacobians_all_inputs_outputs:
        for jacobian in jacobians_all_outputs_and_fixed_input:
            if torch.ne(jacobian, 0).sum() > 0:
                raise GradcheckError(
                    "Numerical gradient for function expected to be zero"
                )
    return True


def _check_no_differentiable_outputs_fast(
    func, func_out, all_inputs, inputs_indices, all_u, eps, nondet_tol
):
    for inp_idx, u in zip(inputs_indices, all_u):
        jvps = _get_numerical_jvp_wrt_specific_input(func, inp_idx, all_inputs, u, eps)
        for jvp in jvps:
            if jvp.numel() == 0:
                continue
            if (jvp - torch.zeros_like(jvp)).abs().max() > nondet_tol:
                raise GradcheckError(
                    "Numerical gradient for function expected to be zero"
                )
    return True


FAILED_BATCHED_GRAD_MSG = """
gradcheck or gradgradcheck failed while testing batched gradient computation.
This could have been invoked in a number of ways (via a test that calls
gradcheck/gradgradcheck directly or via an autogenerated test).

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `check_batched_grad=False` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
  to have `check_batched_grad=False` and/or `check_batched_gradgrad=False`.

If you're modifying an existing operator that supports batched grad computation,
or wish to make a new operator work with batched grad computation, please read
the following.

To compute batched grads (e.g., jacobians, hessians), we vmap over the backward
computation. The most common failure case is if there is a 'vmap-incompatible
operation' in the backward pass. Please see
NOTE: [How to write vmap-compatible backward formulas]
in the codebase for an explanation of how to fix this.
""".strip()

FAILED_BATCHED_GRAD_MSG_FWD_AD = """
gradcheck failed while testing batched gradient computation with forward-mode AD.
This test is enabled automatically when both `check_batched_grad=True`
and `check_forward_ad=True`, but can be disabled in the following ways
depending on how the test was invoked (via a test that calls gradcheck
directly or via an autogenerated test).

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `check_batched_forward_grad=False` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
  to have `check_batched_forward_grad=False`
"""
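
# Illustrative call (hypothetical function and tensor; not part of this module):
# the batched-gradient checks referenced by the messages above can be disabled
# per call when an op's backward is known to be vmap-incompatible.
#
#   import torch
#
#   x = torch.randn(3, dtype=torch.float64, requires_grad=True)
#   torch.autograd.gradcheck(
#       torch.sigmoid,
#       (x,),
#       check_batched_grad=False,
#       check_batched_forward_grad=False,
#   )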


def _get_failed_batched_grad_test_msg(
    output_idx, input_idx, res, exp, is_forward_ad=False
):
    return f"""
For output {output_idx} and input {input_idx}:

{FAILED_BATCHED_GRAD_MSG_FWD_AD if is_forward_ad else FAILED_BATCHED_GRAD_MSG}

Got:
{res}

Expected:
{exp}
""".strip()


def _test_batched_grad_forward_ad(func, inputs) -> bool:
    fwAD = torch.autograd.forward_ad  # To avoid early import issues (do we need this?)
    assert isinstance(inputs, tuple)

    for input_idx, current_input in enumerate(inputs):
        if not (is_tensor_like(current_input) and current_input.requires_grad):
            continue

        def jvp(tangent: torch.Tensor):
            with fwAD.dual_level():
                dual = fwAD.make_dual(current_input.detach(), tangent)
                inputs_with_dual = tuple(
                    dual
                    if idx == input_idx
                    else (inp.detach() if is_tensor_like(inp) else inp)
                    for idx, inp in enumerate(inputs)
                )
                dual_outputs = _as_tuple(func(*inputs_with_dual))
                ret = []
                for dual_output in dual_outputs:
                    if dual_output is None:
                        continue
                    primal_out, tangent_out = fwAD.unpack_dual(dual_output)
                    if tangent_out is not None:
                        ret.append(tangent_out)
                    else:
                        ret.append(
                            torch.zeros(
                                [], dtype=primal_out.dtype, device=primal_out.device
                            ).expand(primal_out.shape)
                        )
                return tuple(ret)

        if not _is_float_or_complex_tensor(current_input):
            continue

        tangents = [torch.randn_like(current_input) for _ in range(2)]
        expected = [jvp(t) for t in tangents]
        expected = [torch.stack(shards) for shards in zip(*expected)]

        try:
            result = _vmap(jvp)(torch.stack(tangents))
        except RuntimeError as ex:
            # Rethrow to provide a better error message
            raise GradcheckError(
                f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}"
            ) from ex

        for input_idx, (res, exp) in enumerate(zip(result, expected)):
            if torch.allclose(res, exp):
                continue
            raise GradcheckError(
                _get_failed_batched_grad_test_msg(
                    input_idx, input_idx, res, exp, is_forward_ad=True
                )
            )
    return True
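
# A minimal sketch of the check above (hypothetical function and tangents; not
# part of this module): batched forward-mode gradients computed with vmap
# should match per-tangent JVPs computed in a Python loop.
#
#   import torch
#
#   def f(x):
#       return x.sin()
#
#   x = torch.randn(3, dtype=torch.float64)
#   tangents = torch.randn(2, 3, dtype=torch.float64)
#
#   def jvp(t):
#       return torch.func.jvp(f, (x,), (t,))[1]
#
#   looped = torch.stack([jvp(t) for t in tangents])
#   batched = torch.vmap(jvp)(tangents)
#   assert torch.allclose(looped, batched)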


def _test_batched_grad(input, output, output_idx) -> bool:
    # NB: _test_batched_grad compares two autograd.grad invocations with a single
    # vmap(autograd.grad) invocation. It's not exactly a "gradcheck" in the
    # sense that we're not comparing an analytical jacobian with a numeric one,
    # but it is morally similar (we could have computed a full analytic jac
    # via vmap, but that is potentially slow)
    diff_input_list = list(_iter_tensors(input, True))
    grad = functools.partial(
        torch.autograd.grad,
        output,
        diff_input_list,
        retain_graph=True,
        allow_unused=True,
    )

    def vjp(v):
        results = grad(v)
        results = tuple(
            grad
            if grad is not None
            else torch.zeros([], dtype=inp.dtype, device=inp.device).expand(inp.shape)
            for grad, inp in zip(results, diff_input_list)
        )
        return results

    grad_outputs = [torch.randn_like(output) for _ in range(2)]

    expected = [vjp(gO) for gO in grad_outputs]
    expected = [torch.stack(shards) for shards in zip(*expected)]

    # Squash warnings since these are expected to happen in most cases
    # NB: this doesn't work for CUDA tests: https://github.com/pytorch/pytorch/issues/50209
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="There is a performance drop")
        warnings.filterwarnings("ignore", message="Please use torch.vmap")
        try:
            result = vmap(vjp)(torch.stack(grad_outputs))
        except RuntimeError as ex:
            # It's OK that we're not raising the error at the correct callsite.
            # That's because the callsite is always going to be inside the Python
            # autograd.grad call rather than the C++ traceback pointing at the
            # offending line in the backward formula.
            raise GradcheckError(
                f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}"
            ) from ex

    for input_idx, (res, exp) in enumerate(zip(result, expected)):
        if torch.allclose(res, exp):
            continue
        raise GradcheckError(
            _get_failed_batched_grad_test_msg(output_idx, input_idx, res, exp)
        )
    return True


def _test_backward_mul_by_grad_output(outputs, inputs, masked) -> bool:
    # Tests that backward is multiplied by grad_output
    diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
    if not diff_input_list:
        raise GradcheckError("no Tensors requiring grad found in input")
    grads_input = torch.autograd.grad(
        outputs,
        diff_input_list,
        [
            torch.zeros_like(o, memory_format=torch.legacy_contiguous_format)
            for o in outputs
        ],
        allow_unused=True,
    )
    for gi, di in zip(grads_input, diff_input_list):
        if gi is None:
            continue
        if isinstance(gi, torch.Tensor) and gi.layout != torch.strided:
            if gi.layout != di.layout:
                raise GradcheckError(
                    "grad is incorrect layout ("
                    + str(gi.layout)
                    + " is not "
                    + str(di.layout)
                    + ")"
                )
            if _is_sparse_any_tensor(gi):
                sparse_kind = str(gi.layout).replace("torch.", "").replace("_coo", "")
                if gi.sparse_dim() != di.sparse_dim():
                    raise GradcheckError(
                        f"grad is {sparse_kind} tensor, but has incorrect sparse_dim"
incorrect sparse_dim" 1203*da0073e9SAndroid Build Coastguard Worker f" {gi.sparse_dim()}, expected {di.sparse_dim()}" 1204*da0073e9SAndroid Build Coastguard Worker ) 1205*da0073e9SAndroid Build Coastguard Worker if gi.dense_dim() != di.dense_dim(): 1206*da0073e9SAndroid Build Coastguard Worker raise GradcheckError( 1207*da0073e9SAndroid Build Coastguard Worker f"grad is {sparse_kind} tensor, but has incorrect dense_dim" 1208*da0073e9SAndroid Build Coastguard Worker f" {gi.dense_dim()}, expected {di.dense_dim()}" 1209*da0073e9SAndroid Build Coastguard Worker ) 1210*da0073e9SAndroid Build Coastguard Worker gi = gi.to_dense() 1211*da0073e9SAndroid Build Coastguard Worker di = di.to_dense() 1212*da0073e9SAndroid Build Coastguard Worker if masked: 1213*da0073e9SAndroid Build Coastguard Worker if not torch.allclose(gi, torch.zeros_like(gi)): 1214*da0073e9SAndroid Build Coastguard Worker raise GradcheckError("backward not multiplied by grad_output") 1215*da0073e9SAndroid Build Coastguard Worker elif not gi.eq(0).all(): 1216*da0073e9SAndroid Build Coastguard Worker raise GradcheckError("backward not multiplied by grad_output") 1217*da0073e9SAndroid Build Coastguard Worker if gi.dtype != di.dtype: 1218*da0073e9SAndroid Build Coastguard Worker raise GradcheckError("grad is incorrect type") 1219*da0073e9SAndroid Build Coastguard Worker if gi.device != di.device: 1220*da0073e9SAndroid Build Coastguard Worker raise GradcheckError("grad is incorrect device") 1221*da0073e9SAndroid Build Coastguard Worker if gi.size() != di.size(): 1222*da0073e9SAndroid Build Coastguard Worker raise GradcheckError("grad is incorrect size") 1223*da0073e9SAndroid Build Coastguard Worker return True 1224*da0073e9SAndroid Build Coastguard Worker 1225*da0073e9SAndroid Build Coastguard Worker 1226*da0073e9SAndroid Build Coastguard Workerdef _test_undefined_forward_mode(func, outputs, inputs): 1227*da0073e9SAndroid Build Coastguard Worker fwAD = torch.autograd.forward_ad 1228*da0073e9SAndroid Build Coastguard Worker 1229*da0073e9SAndroid Build Coastguard Worker inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs) 1230*da0073e9SAndroid Build Coastguard Worker all_v, all_u, all_u_dense = _make_vectors(inp_tensors, outputs, use_forward_ad=True) 1231*da0073e9SAndroid Build Coastguard Worker 1232*da0073e9SAndroid Build Coastguard Worker tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad) 1233*da0073e9SAndroid Build Coastguard Worker 1234*da0073e9SAndroid Build Coastguard Worker with fwAD.dual_level(): 1235*da0073e9SAndroid Build Coastguard Worker fw_grads = [] 1236*da0073e9SAndroid Build Coastguard Worker dual_inputs = [] 1237*da0073e9SAndroid Build Coastguard Worker tensor_indices = set() 1238*da0073e9SAndroid Build Coastguard Worker for i, inp in enumerate(inputs): 1239*da0073e9SAndroid Build Coastguard Worker if is_tensor_like(inp) and inp.requires_grad: 1240*da0073e9SAndroid Build Coastguard Worker if inp.layout == torch._mkldnn: # type: ignore[attr-defined] 1241*da0073e9SAndroid Build Coastguard Worker raise ValueError( 1242*da0073e9SAndroid Build Coastguard Worker "MKLDNN inputs are not support for forward AD gradcheck." 


def _test_undefined_forward_mode(func, outputs, inputs):
    fwAD = torch.autograd.forward_ad

    inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs)
    all_v, all_u, all_u_dense = _make_vectors(inp_tensors, outputs, use_forward_ad=True)

    tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad)

    with fwAD.dual_level():
        fw_grads = []
        dual_inputs = []
        tensor_indices = set()
        for i, inp in enumerate(inputs):
            if is_tensor_like(inp) and inp.requires_grad:
                if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
                    raise ValueError(
                        "MKLDNN inputs are not supported for forward AD gradcheck."
                    )

                inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
                # If inp is a differentiable view, the dual might not be the tangent given to
                # make_dual, so read it explicitly from the dual tensor
                fw_grads.append(fwAD.unpack_dual(inp)[1])
                tensor_indices.add(i)
            dual_inputs.append(inp)

        for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
            fw_grad.copy_(u.view_as(fw_grad))

        for idx, inp in enumerate(inputs):
            if idx not in tensor_indices:
                continue
            dual_inp_obj = dual_inputs[idx]

            # case 1 (Materialized Zero Tensor Tangent)
            dual_inputs[idx] = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
            raw_outputs = _as_tuple(func(*dual_inputs))
            dual_outputs1 = filter(_is_float_or_complex_tensor, raw_outputs)

            # case 2 (Efficient Zero Tensor Tangent since we don't make a dual object and pass a regular tensor)
            dual_inputs[idx] = inp.detach()
            raw_outputs = _as_tuple(func(*dual_inputs))
            dual_outputs2 = filter(_is_float_or_complex_tensor, raw_outputs)

            # reset
            dual_inputs[idx] = dual_inp_obj

            for index_o, (d_o1, d_o2) in enumerate(zip(dual_outputs1, dual_outputs2)):
                val1, res1 = fwAD.unpack_dual(d_o1)
                val2, res2 = fwAD.unpack_dual(d_o2)

                if not (res1 is None or res2 is None):
                    if not torch.allclose(res1, res2):
                        raise GradcheckError(
                            "Mismatch in tangent values for output with index: ",
                            index_o,
                            " when input: ",
                            inp,
                            " has an undefined tangent value. ",
", 1285*da0073e9SAndroid Build Coastguard Worker " Got: ", 1286*da0073e9SAndroid Build Coastguard Worker res1, 1287*da0073e9SAndroid Build Coastguard Worker " but expected: ", 1288*da0073e9SAndroid Build Coastguard Worker res2, 1289*da0073e9SAndroid Build Coastguard Worker ) 1290*da0073e9SAndroid Build Coastguard Worker return True 1291*da0073e9SAndroid Build Coastguard Worker 1292*da0073e9SAndroid Build Coastguard Worker 1293*da0073e9SAndroid Build Coastguard Workerdef _test_undefined_backward_mode(func, outputs, inputs) -> bool: 1294*da0073e9SAndroid Build Coastguard Worker diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True)) 1295*da0073e9SAndroid Build Coastguard Worker if not diff_input_list: 1296*da0073e9SAndroid Build Coastguard Worker raise GradcheckError("no Tensors requiring grad found in input") 1297*da0073e9SAndroid Build Coastguard Worker 1298*da0073e9SAndroid Build Coastguard Worker def warn_bc_breaking(): 1299*da0073e9SAndroid Build Coastguard Worker warnings.warn( 1300*da0073e9SAndroid Build Coastguard Worker "Backwards compatibility: New undefined gradient support checking " 1301*da0073e9SAndroid Build Coastguard Worker "feature is enabled by default, but it may break existing callers " 1302*da0073e9SAndroid Build Coastguard Worker "of this function. If this is true for you, you can call this " 1303*da0073e9SAndroid Build Coastguard Worker 'function with "check_undefined_grad=False" to disable the feature' 1304*da0073e9SAndroid Build Coastguard Worker ) 1305*da0073e9SAndroid Build Coastguard Worker 1306*da0073e9SAndroid Build Coastguard Worker def check_undefined_grad_support(output_to_check): 1307*da0073e9SAndroid Build Coastguard Worker grads_output = [ 1308*da0073e9SAndroid Build Coastguard Worker torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) 1309*da0073e9SAndroid Build Coastguard Worker for o in output_to_check 1310*da0073e9SAndroid Build Coastguard Worker ] 1311*da0073e9SAndroid Build Coastguard Worker try: 1312*da0073e9SAndroid Build Coastguard Worker grads_input = torch.autograd.grad( 1313*da0073e9SAndroid Build Coastguard Worker output_to_check, diff_input_list, grads_output, allow_unused=True 1314*da0073e9SAndroid Build Coastguard Worker ) 1315*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 1316*da0073e9SAndroid Build Coastguard Worker warn_bc_breaking() 1317*da0073e9SAndroid Build Coastguard Worker raise GradcheckError( 1318*da0073e9SAndroid Build Coastguard Worker "Expected backward function to handle undefined output grads. " 1319*da0073e9SAndroid Build Coastguard Worker 'Please look at "Notes about undefined output gradients" in ' 1320*da0073e9SAndroid Build Coastguard Worker '"tools/autograd/derivatives.yaml"' 1321*da0073e9SAndroid Build Coastguard Worker ) from e 1322*da0073e9SAndroid Build Coastguard Worker 1323*da0073e9SAndroid Build Coastguard Worker for gi, i in zip(grads_input, diff_input_list): 1324*da0073e9SAndroid Build Coastguard Worker if (gi is not None) and (not gi.eq(0).all()): 1325*da0073e9SAndroid Build Coastguard Worker warn_bc_breaking() 1326*da0073e9SAndroid Build Coastguard Worker raise GradcheckError( 1327*da0073e9SAndroid Build Coastguard Worker "Expected all input grads to be undefined or zero when all output grads are undefined " 1328*da0073e9SAndroid Build Coastguard Worker 'or zero. 
Please look at "Notes about undefined output gradients" in ' 1329*da0073e9SAndroid Build Coastguard Worker '"tools/autograd/derivatives.yaml"' 1330*da0073e9SAndroid Build Coastguard Worker ) 1331*da0073e9SAndroid Build Coastguard Worker return True 1332*da0073e9SAndroid Build Coastguard Worker 1333*da0073e9SAndroid Build Coastguard Worker # All backward functions must work properly if all output grads are undefined 1334*da0073e9SAndroid Build Coastguard Worker outputs_to_check = [ 1335*da0073e9SAndroid Build Coastguard Worker [ 1336*da0073e9SAndroid Build Coastguard Worker torch._C._functions.UndefinedGrad()(o) 1337*da0073e9SAndroid Build Coastguard Worker for o in _differentiable_outputs(func(*inputs)) 1338*da0073e9SAndroid Build Coastguard Worker # This check filters out Tensor-likes that aren't instances of Tensor. 1339*da0073e9SAndroid Build Coastguard Worker if isinstance(o, torch.Tensor) 1340*da0073e9SAndroid Build Coastguard Worker ] 1341*da0073e9SAndroid Build Coastguard Worker ] 1342*da0073e9SAndroid Build Coastguard Worker 1343*da0073e9SAndroid Build Coastguard Worker # If there are multiple output grads, we should be able to undef one at a time without error 1344*da0073e9SAndroid Build Coastguard Worker if len(outputs_to_check[0]) > 1: 1345*da0073e9SAndroid Build Coastguard Worker for undef_grad_idx in range(len(outputs)): 1346*da0073e9SAndroid Build Coastguard Worker output_to_check = _differentiable_outputs(func(*inputs)) 1347*da0073e9SAndroid Build Coastguard Worker outputs_to_check.append( 1348*da0073e9SAndroid Build Coastguard Worker [ 1349*da0073e9SAndroid Build Coastguard Worker torch._C._functions.UndefinedGrad()(o) 1350*da0073e9SAndroid Build Coastguard Worker if idx == undef_grad_idx 1351*da0073e9SAndroid Build Coastguard Worker else o 1352*da0073e9SAndroid Build Coastguard Worker for idx, o in enumerate(output_to_check) 1353*da0073e9SAndroid Build Coastguard Worker ] 1354*da0073e9SAndroid Build Coastguard Worker ) 1355*da0073e9SAndroid Build Coastguard Worker 1356*da0073e9SAndroid Build Coastguard Worker return all(check_undefined_grad_support(output) for output in outputs_to_check) 1357*da0073e9SAndroid Build Coastguard Worker 1358*da0073e9SAndroid Build Coastguard Worker 1359*da0073e9SAndroid Build Coastguard Workerdef _as_tuple(x): 1360*da0073e9SAndroid Build Coastguard Worker if isinstance(x, tuple): 1361*da0073e9SAndroid Build Coastguard Worker return x 1362*da0073e9SAndroid Build Coastguard Worker elif isinstance(x, list): 1363*da0073e9SAndroid Build Coastguard Worker return tuple(x) 1364*da0073e9SAndroid Build Coastguard Worker else: 1365*da0073e9SAndroid Build Coastguard Worker return (x,) 1366*da0073e9SAndroid Build Coastguard Worker 1367*da0073e9SAndroid Build Coastguard Worker 1368*da0073e9SAndroid Build Coastguard Workerdef _differentiable_outputs(x): 1369*da0073e9SAndroid Build Coastguard Worker return tuple(o for o in _as_tuple(x) if o.requires_grad) 1370*da0073e9SAndroid Build Coastguard Worker 1371*da0073e9SAndroid Build Coastguard Worker 1372*da0073e9SAndroid Build Coastguard Workerdef _get_notallclose_msg( 1373*da0073e9SAndroid Build Coastguard Worker analytical, 1374*da0073e9SAndroid Build Coastguard Worker numerical, 1375*da0073e9SAndroid Build Coastguard Worker output_idx, 1376*da0073e9SAndroid Build Coastguard Worker input_idx, 1377*da0073e9SAndroid Build Coastguard Worker complex_indices, 1378*da0073e9SAndroid Build Coastguard Worker test_imag=False, 1379*da0073e9SAndroid Build Coastguard Worker is_forward_ad=False, 1380*da0073e9SAndroid Build 
) -> str:
    out_is_complex = (
        (not is_forward_ad) and complex_indices and output_idx in complex_indices
    )
    inp_is_complex = is_forward_ad and complex_indices and input_idx in complex_indices
    part = "imaginary" if test_imag else "real"
    element = "inputs" if is_forward_ad else "outputs"
    prefix = (
        ""
        if not (out_is_complex or inp_is_complex)
        else f"While considering the {part} part of complex {element} only, "
    )
    mode = "computed with forward mode " if is_forward_ad else ""
    return (
        prefix + "Jacobian %smismatch for output %d with respect to input %d,\n"
        "numerical:%s\nanalytical:%s\n"
        % (mode, output_idx, input_idx, numerical, analytical)
    )


def _transpose(matrix_of_tensors):
    # returns list of tuples
    return list(zip(*matrix_of_tensors))


def _real_and_imag_output(fn):
    # returns new functions real(fn) and imag(fn), which behave the same as
    # the original fn, except torch.real or torch.imag is applied to the complex outputs
    def apply_to_c_outs(fn, fn_to_apply):
        def wrapped_fn(*inputs):
            outs = _as_tuple(fn(*inputs))
            return tuple(fn_to_apply(o) if o.is_complex() else o for o in outs)

        return wrapped_fn

    return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch.imag)
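
# A minimal sketch of the decomposition used above (hypothetical function; not
# part of this module): a complex-output function is gradchecked by separately
# checking its real and imaginary parts, each of which is a real-valued
# function of the same inputs.
#
#   import torch
#
#   def fn(x):
#       return x * (1 + 2j)
#
#   x = torch.randn(3, dtype=torch.complex128, requires_grad=True)
#   real_fn = lambda t: torch.real(fn(t))
#   imag_fn = lambda t: torch.imag(fn(t))
#   assert torch.autograd.gradcheck(real_fn, (x,))
#   assert torch.autograd.gradcheck(imag_fn, (x,))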

def _real_and_imag_input(fn, complex_inp_indices, tupled_inputs):
    # returns new functions that take real inputs instead of complex inputs,
    # viewing fn as (x, y) -> fn(x + y * 1j). It builds inp -> fn(inp + y * 1j)
    # and inp -> fn(x + inp * 1j); in each case, the other part is held constant.
    # We do not use 0 for the constant here to make sure we always call the user function with a valid input.
    def apply_to_c_inps(fn, fn_to_apply):
        def wrapped_fn(*inputs):
            new_inputs = list(inputs)
            for should_be_complex in complex_inp_indices:
                new_inputs[should_be_complex] = fn_to_apply(
                    new_inputs[should_be_complex], tupled_inputs[should_be_complex]
                )
            return _as_tuple(fn(*new_inputs))

        return wrapped_fn

    real_fn = apply_to_c_inps(fn, lambda inp, orig: inp + orig.imag * 1j)
    imag_fn = apply_to_c_inps(fn, lambda inp, orig: orig.real + inp * 1j)
    return real_fn, imag_fn


def _gradcheck_real_imag(
    gradcheck_fn,
    func,
    func_out,
    tupled_inputs,
    outputs,
    eps,
    rtol,
    atol,
    check_grad_dtypes,
    check_forward_ad,
    check_backward_ad,
    nondet_tol,
    check_undefined_grad,
):
    complex_out_indices = [i for i, o in enumerate(outputs) if o.is_complex()]
    has_any_complex_output = any(o.is_complex() for o in _as_tuple(func_out))
    if check_backward_ad:
        if has_any_complex_output:
            real_fn, imag_fn = _real_and_imag_output(func)

            imag_func_out = imag_fn(*tupled_inputs)
            imag_outputs = _differentiable_outputs(imag_func_out)
            gradcheck_fn(
                imag_fn,
                imag_func_out,
                tupled_inputs,
                imag_outputs,
                eps,
                rtol,
                atol,
                check_grad_dtypes,
                nondet_tol,
                complex_indices=complex_out_indices,
                test_imag=True,
            )

            real_func_out = real_fn(*tupled_inputs)
            real_outputs = _differentiable_outputs(real_func_out)
            gradcheck_fn(
                real_fn,
                real_func_out,
                tupled_inputs,
                real_outputs,
                eps,
                rtol,
                atol,
                check_grad_dtypes,
                nondet_tol,
                complex_indices=complex_out_indices,
            )
        else:
            gradcheck_fn(
                func,
                func_out,
                tupled_inputs,
                outputs,
                eps,
                rtol,
                atol,
                check_grad_dtypes,
                nondet_tol,
            )

    if check_forward_ad:
        complex_inp_indices = [
            i
            for i, inp in enumerate(tupled_inputs)
            if is_tensor_like(inp) and inp.is_complex()
        ]
        if complex_inp_indices:
            real_fn, imag_fn = _real_and_imag_input(
                func, complex_inp_indices, tupled_inputs
            )

            imag_inputs = [
                inp.imag if is_tensor_like(inp) and inp.is_complex() else inp
                for inp in tupled_inputs
            ]
            imag_func_out = imag_fn(*imag_inputs)
            diff_imag_func_out = _differentiable_outputs(imag_func_out)
            gradcheck_fn(
                imag_fn,
                imag_func_out,
                imag_inputs,
                diff_imag_func_out,
                eps,
                rtol,
                atol,
                check_grad_dtypes,
                nondet_tol,
                complex_indices=complex_inp_indices,
Build Coastguard Worker test_imag=True, 1532*da0073e9SAndroid Build Coastguard Worker use_forward_ad=True, 1533*da0073e9SAndroid Build Coastguard Worker ) 1534*da0073e9SAndroid Build Coastguard Worker 1535*da0073e9SAndroid Build Coastguard Worker real_inputs = [ 1536*da0073e9SAndroid Build Coastguard Worker inp.real if is_tensor_like(inp) and inp.is_complex() else inp 1537*da0073e9SAndroid Build Coastguard Worker for inp in tupled_inputs 1538*da0073e9SAndroid Build Coastguard Worker ] 1539*da0073e9SAndroid Build Coastguard Worker real_func_out = real_fn(*real_inputs) 1540*da0073e9SAndroid Build Coastguard Worker diff_real_func_out = _differentiable_outputs(real_func_out) 1541*da0073e9SAndroid Build Coastguard Worker gradcheck_fn( 1542*da0073e9SAndroid Build Coastguard Worker real_fn, 1543*da0073e9SAndroid Build Coastguard Worker real_func_out, 1544*da0073e9SAndroid Build Coastguard Worker real_inputs, 1545*da0073e9SAndroid Build Coastguard Worker diff_real_func_out, 1546*da0073e9SAndroid Build Coastguard Worker eps, 1547*da0073e9SAndroid Build Coastguard Worker rtol, 1548*da0073e9SAndroid Build Coastguard Worker atol, 1549*da0073e9SAndroid Build Coastguard Worker check_grad_dtypes, 1550*da0073e9SAndroid Build Coastguard Worker nondet_tol, 1551*da0073e9SAndroid Build Coastguard Worker complex_indices=complex_inp_indices, 1552*da0073e9SAndroid Build Coastguard Worker use_forward_ad=True, 1553*da0073e9SAndroid Build Coastguard Worker ) 1554*da0073e9SAndroid Build Coastguard Worker if check_undefined_grad: 1555*da0073e9SAndroid Build Coastguard Worker _test_undefined_forward_mode(imag_fn, imag_func_out, imag_inputs) 1556*da0073e9SAndroid Build Coastguard Worker _test_undefined_forward_mode(real_fn, real_func_out, real_inputs) 1557*da0073e9SAndroid Build Coastguard Worker else: 1558*da0073e9SAndroid Build Coastguard Worker gradcheck_fn( 1559*da0073e9SAndroid Build Coastguard Worker func, 1560*da0073e9SAndroid Build Coastguard Worker func_out, 1561*da0073e9SAndroid Build Coastguard Worker tupled_inputs, 1562*da0073e9SAndroid Build Coastguard Worker outputs, 1563*da0073e9SAndroid Build Coastguard Worker eps, 1564*da0073e9SAndroid Build Coastguard Worker rtol, 1565*da0073e9SAndroid Build Coastguard Worker atol, 1566*da0073e9SAndroid Build Coastguard Worker check_grad_dtypes, 1567*da0073e9SAndroid Build Coastguard Worker nondet_tol, 1568*da0073e9SAndroid Build Coastguard Worker use_forward_ad=True, 1569*da0073e9SAndroid Build Coastguard Worker ) 1570*da0073e9SAndroid Build Coastguard Worker if check_undefined_grad: 1571*da0073e9SAndroid Build Coastguard Worker _test_undefined_forward_mode(func, outputs, tupled_inputs) 1572*da0073e9SAndroid Build Coastguard Worker 1573*da0073e9SAndroid Build Coastguard Worker 1574*da0073e9SAndroid Build Coastguard Workerdef _slow_gradcheck( 1575*da0073e9SAndroid Build Coastguard Worker func, 1576*da0073e9SAndroid Build Coastguard Worker func_out, 1577*da0073e9SAndroid Build Coastguard Worker tupled_inputs, 1578*da0073e9SAndroid Build Coastguard Worker outputs, 1579*da0073e9SAndroid Build Coastguard Worker eps, 1580*da0073e9SAndroid Build Coastguard Worker rtol, 1581*da0073e9SAndroid Build Coastguard Worker atol, 1582*da0073e9SAndroid Build Coastguard Worker check_grad_dtypes, 1583*da0073e9SAndroid Build Coastguard Worker nondet_tol, 1584*da0073e9SAndroid Build Coastguard Worker *, 1585*da0073e9SAndroid Build Coastguard Worker use_forward_ad=False, 1586*da0073e9SAndroid Build Coastguard Worker complex_indices=None, 1587*da0073e9SAndroid Build Coastguard Worker 
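# Illustrative sketch (not part of this module's API; the helper name below is
# hypothetical): the real/imag handling above can be pictured as wrapping a
# complex-output function into two real-valued functions and gradchecking each one
# separately. The actual splitting is done by _real_and_imag_output / _real_and_imag_input.
def _example_split_complex_output(fn):
    # Assumes `fn` returns a single complex tensor.
    def real_part(*args):
        return fn(*args).real

    def imag_part(*args):
        return fn(*args).imag

    return real_part, imag_part

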
def _slow_gradcheck(
    func,
    func_out,
    tupled_inputs,
    outputs,
    eps,
    rtol,
    atol,
    check_grad_dtypes,
    nondet_tol,
    *,
    use_forward_ad=False,
    complex_indices=None,
    test_imag=False,
    masked=False,
):
    func_out = _as_tuple(func_out)
    if not outputs:
        return _check_no_differentiable_outputs(
            func, tupled_inputs, func_out, eps=eps, is_forward_ad=use_forward_ad
        )
    tupled_inputs_numerical = tupled_inputs if masked else _densify(tupled_inputs)

    numerical = _transpose(
        _get_numerical_jacobian(
            func,
            tupled_inputs_numerical,
            func_out,
            eps=eps,
            is_forward_ad=use_forward_ad,
        )
    )
    # Note: [numerical vs analytical output length]
    # The numerical path returns a jacobian quantity for all outputs, even if requires_grad of that
    # output is False. This behavior is necessary for _check_no_differentiable_outputs to work.
    numerical = [nj for o, nj in zip(func_out, numerical) if o.requires_grad]
    if use_forward_ad:
        analytical_forward = _get_analytical_jacobian_forward_ad(
            func, tupled_inputs, func_out, check_grad_dtypes=check_grad_dtypes
        )

        for i, n_per_out in enumerate(numerical):
            for j, n in enumerate(n_per_out):
                a = analytical_forward[j][i]
                if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol):
                    raise GradcheckError(
                        _get_notallclose_msg(
                            a, n, i, j, complex_indices, test_imag, is_forward_ad=True
                        )
                    )
    else:
        for i, o in enumerate(outputs):
            analytical = _check_analytical_jacobian_attributes(
                tupled_inputs, o, nondet_tol, check_grad_dtypes
            )

            for j, (a, n) in enumerate(zip(analytical, numerical[i])):
                if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol):
                    raise GradcheckError(
                        _get_notallclose_msg(a, n, i, j, complex_indices, test_imag)
                    )

    return True


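# Illustrative sketch (hypothetical helper, not used by gradcheck itself): the numerical
# path above is built on central differences with step `eps`, in the spirit of the
# snippet below, which estimates the derivative of `fn` with respect to one flattened
# element of `x`.
def _example_central_difference(fn, x, idx, eps=1e-6):
    # Perturb a single element of a detached copy of x by +/- eps and difference the outputs.
    x = x.detach().clone()
    flat = x.reshape(-1)
    orig = flat[idx].item()
    flat[idx] = orig + eps
    out_plus = fn(x).detach().clone()
    flat[idx] = orig - eps
    out_minus = fn(x).detach().clone()
    flat[idx] = orig
    return (out_plus - out_minus) / (2 * eps)

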
def _dot_with_type_promotion(u, v):
    assert u.dim() == 1 and v.dim() == 1
    return (u * v).sum()


def _allclose_with_type_promotion(a, b, rtol, atol):
    promoted_type = torch.promote_types(a.dtype, b.dtype)
    a = a.to(dtype=promoted_type)
    b = b.to(dtype=promoted_type)
    return torch.allclose(a, b, rtol, atol)


def _to_real_dtype(dtype):
    if dtype == torch.complex128:
        return torch.float64
    elif dtype == torch.complex64:
        return torch.float32
    else:
        return dtype


def _vec_from_tensor(x, generator, downcast_complex=False):
    # Create a random vector with the same number of elements as x and the same
    # dtype/device. If x is complex and downcast_complex is False, we create a
    # complex tensor with only real component.
    if x.layout == torch.sparse_coo:
        # For sparse, create a random sparse vec with random values in the same
        # indices. Make sure size is set so that it isn't inferred to be smaller.
        x_values = x._values()
        dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype
        values = (
            torch.rand(x_values.numel(), generator=generator)
            .to(dtype=dtype, device=x.device)
            .view(x_values.shape)
        )
        values /= values.norm()
        vec = torch.sparse_coo_tensor(x._indices(), values, x.size(), device=x.device)
    elif _is_sparse_compressed_tensor(x):
        if x.layout in {torch.sparse_csr, torch.sparse_bsr}:
            compressed_indices, plain_indices = x.crow_indices(), x.col_indices()
        else:
            compressed_indices, plain_indices = x.ccol_indices(), x.row_indices()
        x_values = x.values()
        dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype
        values = (
            torch.rand(x_values.numel(), generator=generator)
            .to(dtype=dtype, device=x.device)
            .view(x_values.shape)
        )
        values /= values.norm()
        vec = torch.sparse_compressed_tensor(
            compressed_indices,
            plain_indices,
            values,
            x.size(),
            layout=x.layout,
            device=x.device,
        )
    else:
        dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype
        vec = torch.rand(x.numel(), generator=generator).to(
            dtype=dtype, device=x.device
        )
        vec /= vec.norm()
    return vec


def _get_inp_tensors(tupled_inputs):
    inp_idx_tup = [
        (i, t)
        for i, t in enumerate(tupled_inputs)
        if is_tensor_like(t) and t.requires_grad
    ]
    return [tup[0] for tup in inp_idx_tup], [tup[1] for tup in inp_idx_tup]


def _adjusted_atol(atol, u, v):
    # In slow gradcheck, we compare A and B element-wise, i.e., for some a, b we
    # allow: |a - b| < atol + rtol * b. But since we now compare q1 = v^T A u and
    # q2 = v^T B u, we must allow |q1 - q2| < v^T E u + rtol * v^T B u, where E is
    # the correctly sized matrix in which each entry is atol.
    #
    # We see that atol needs to be scaled by v^T M u (where M is an all-ones M x N
    # matrix): v^T M u = \sum_{i} \sum_{j} u_i * v_j = (\sum_{i} u_i)(\sum_{j} v_j)
    # TODO: properly handle case when u is tuple instead of only taking first element
    u = u[0] if isinstance(u, tuple) else u
    sum_u = u.sum()
    sum_v = 1.0 if v is None else v.sum()
    return atol * float(sum_u) * float(sum_v)


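# Worked illustration of the scaling above (hypothetical values; the helper is not used
# anywhere): with atol=1e-5, a probe vector u summing to 2.0, and v summing to 3.0, the
# adjusted tolerance for |v^T A u - v^T B u| is 1e-5 * 2.0 * 3.0 == 6e-5.
def _example_adjusted_atol():
    u = torch.tensor([0.5, 1.5])  # sum(u) == 2.0
    v = torch.tensor([1.0, 2.0])  # sum(v) == 3.0
    return _adjusted_atol(1e-5, u, v)  # == 6e-5 (up to float rounding)

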
FAST_FAIL_SLOW_OK_MSG = """
Fast gradcheck failed but element-wise differences are small. This means that the
test might've passed in slow_mode!

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck:

If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `fast_mode=False` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
  to have `gradcheck_fast_mode=False`
- is a Module test (e.g., in common_nn.py), then modify the corresponding
  module_test entry to have `gradcheck_fast_mode=False`
""".strip()


def _run_slow_mode_and_get_error(
    func, tupled_inputs, outputs, input_idx, output_idx, rtol, atol, eps, is_forward_ad
):
    # Compute jacobians in slow mode for better error message
    slow_numerical = _get_numerical_jacobian(
        func, tupled_inputs, outputs, eps=eps, is_forward_ad=is_forward_ad
    )[input_idx][output_idx]
    if is_forward_ad:

        def new_fn(inp):
            new_inputs = list(tupled_inputs)
            new_inputs[input_idx] = inp
            return _as_tuple(func(*new_inputs))[output_idx]

        slow_analytical = _get_analytical_jacobian_forward_ad(
            new_fn, (tupled_inputs[input_idx],), (outputs[output_idx],)
        )[0][0]
    else:
        slow_analytical = _get_analytical_jacobian(
            tupled_inputs, outputs, input_idx, output_idx
        )

    # Assume jacobians are non-empty and have the same shape
    slow_max_diff = (slow_numerical - slow_analytical).abs().max()

    slow_allclose = torch.allclose(slow_analytical, slow_numerical, rtol, atol)
    msg = (
        "\nThe above quantities relating the numerical and analytical jacobians are computed \n"
        "in fast mode. See: https://github.com/pytorch/pytorch/issues/53876 for more background \n"
        "about fast mode. Below, we recompute numerical and analytical jacobians in slow mode:\n\n"
        f"Numerical:\n {slow_numerical}\n"
        f"Analytical:\n{slow_analytical}\n\n"
        f"The max per-element difference (slow mode) is: {slow_max_diff}.\n"
    )
    if slow_allclose:
        # Slow gradcheck would've passed!
        msg += FAST_FAIL_SLOW_OK_MSG
    return msg


def _to_flat_dense_if_sparse(tensor):
    if _is_sparse_any_tensor(tensor):
        return tensor.to_dense().reshape(-1)
    else:
        return tensor


def _make_vectors(inp_tensors, outputs, *, use_forward_ad):
    # Use our own generator to avoid messing with the user's RNG state
    g_cpu = torch.Generator()

    def _vec_from_tensor_cpu(*args):
        # Default allocate all tensors on CPU, so they are on the same device as the generator
        # even if the user specified a default device
        with torch.device("cpu"):
            return _vec_from_tensor(*args)

    all_u = []
    all_u_dense = []
    for inp in inp_tensors:
        ur = _vec_from_tensor_cpu(inp, g_cpu, True)
        ur_dense = _to_flat_dense_if_sparse(ur)
        if inp.is_complex():
            ui = _vec_from_tensor_cpu(inp, g_cpu, True)
            all_u.append((ur, ui))
            ui_dense = _to_flat_dense_if_sparse(ui)
            all_u_dense.append((ur_dense, ui_dense))
        else:
            all_u.append(ur)
            all_u_dense.append(ur_dense)
    all_v = (
        None
        if use_forward_ad
        else [_vec_from_tensor_cpu(out, g_cpu) for out in outputs]
    )
    return all_v, all_u, all_u_dense


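# Illustrative note (hypothetical snippet, not used by _make_vectors): using a dedicated
# torch.Generator, as above, keeps gradcheck's probe vectors from consuming the global
# RNG stream, so a caller's torch.manual_seed(...) reproducibility is unaffected.
def _example_private_generator(n):
    g = torch.Generator()
    g.manual_seed(0)  # deterministic probe, independent of the global RNG state
    vec = torch.rand(n, generator=g)
    return vec / vec.norm()

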
def _check_analytical_numerical_equal(
    all_analytical,
    all_numerical,
    complex_indices,
    tupled_inputs,
    outputs,
    func,
    all_v,
    all_u,
    rtol,
    atol,
    eps,
    test_imag,
    *,
    is_forward_ad=False,
):
    for i, all_numerical_for_input_i in enumerate(all_numerical):
        for j, n in enumerate(all_numerical_for_input_i):
            # Forward AD generates the transpose of what this function expects
            if is_forward_ad:
                a = all_analytical[i][j]
            else:
                a = all_analytical[j][i]
            n = n.to(device=a.device)
            updated_atol = _adjusted_atol(atol, all_u[i], all_v[j] if all_v else None)
            if not _allclose_with_type_promotion(a, n.to(a.device), rtol, updated_atol):
                jacobians_str = _run_slow_mode_and_get_error(
                    func, tupled_inputs, outputs, i, j, rtol, atol, eps, is_forward_ad
                )
                raise GradcheckError(
                    _get_notallclose_msg(
                        a, n, j, i, complex_indices, test_imag, is_forward_ad
                    )
                    + jacobians_str
                )


def _fast_gradcheck(
    func,
    func_out,
    inputs,
    outputs,
    eps,
    rtol,
    atol,
    check_grad_dtypes,
    nondet_tol,
    *,
    use_forward_ad=False,
    complex_indices=None,
    test_imag=False,
    masked=False,
):
    # See https://github.com/pytorch/pytorch/issues/53876 for details
    inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs)
    # Backward mode computes v^T * J (VJP)
    # Since we computed J * u (JVP) through the finite difference method, we perform an
    # equality check between (v^T J) u and v^T (J u)
    # ----
    # Forward mode computes J * u (JVP)
    # Since we already compute the JVP through the finite difference method,
    # we don't need v for the correctness check here, as asserted below
    all_v, all_u, all_u_dense = _make_vectors(
        inp_tensors, outputs, use_forward_ad=use_forward_ad
    )

    inputs_numerical, all_u_numerical, all_v_numerical = (
        (inputs, all_u, all_v) if masked else _densify((inputs, all_u, all_v))
    )

    numerical_vJu = _get_numerical_vJu(
        func,
        inputs_numerical,
        inp_tensors_idx,
        func_out,
        all_u_numerical,
        all_v_numerical,
        eps,
        is_forward_ad=use_forward_ad,
    )
    # TODO: replicate https://github.com/pytorch/pytorch/pull/77743 for fast gradcheck as well
    if use_forward_ad:
        assert all_v is None
        analytical_vJu = _get_analytical_jacobian_forward_ad(
            func,
            inputs,
            _as_tuple(func_out),
            all_u=all_u,
            check_grad_dtypes=check_grad_dtypes,
        )
    else:
        if not outputs:
            _check_no_differentiable_outputs_fast(
                func, func_out, inputs, inp_tensors_idx, all_u, eps, nondet_tol
            )

        analytical_vJu = _get_analytical_vJu_backward_mode(
            inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u_dense
        )

    _check_analytical_numerical_equal(
        analytical_vJu,
        numerical_vJu,
        complex_indices,
        inputs,
        outputs,
        func,
        all_v,
        all_u,
        rtol,
        atol,
        eps,
        test_imag,
        is_forward_ad=use_forward_ad,
    )

    return True


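# Illustrative sketch (hypothetical helper): the scalar compared by fast gradcheck is
# v^T J u. The analytical side can be obtained from backward mode as below: a single
# vector-Jacobian product gives v^T J, which is then dotted with u. The numerical side
# would instead estimate J u with finite differences and dot it with v; for a
# well-behaved function the two scalars agree up to tolerance.
def _example_vJu_backward(fn, x, u, v):
    # v must have the same shape as fn(x); u must have the same number of elements as x.
    x = x.detach().clone().requires_grad_(True)
    out = fn(x)
    (vJ,) = torch.autograd.grad(out, x, grad_outputs=v)
    return _dot_with_type_promotion(vJ.reshape(-1), u.reshape(-1))

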
# Note [VarArg of Tensors]
# ~~~~~~~~~~~~~~~~~~~~~~~~
# 'func' accepts a vararg of tensors, which isn't expressible in the type system at the moment.
# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
# the '...' first argument of Callable can be replaced with VarArg(Tensor).
# For now, we permit any input.
def gradcheck(
    func: Callable[..., _TensorOrTensors],  # See Note [VarArg of Tensors]
    inputs: _TensorOrTensors,
    *,
    eps: float = 1e-6,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    raise_exception: bool = True,
    nondet_tol: float = 0.0,
    check_undefined_grad: bool = True,
    check_grad_dtypes: bool = False,
    check_batched_grad: bool = False,
    check_batched_forward_grad: bool = False,
    check_forward_ad: bool = False,
    check_backward_ad: bool = True,
    fast_mode: bool = False,
    masked: Optional[bool] = None,
) -> bool:  # noqa: D400,D205
    r"""Check gradients computed via small finite differences against analytical
    gradients wrt tensors in :attr:`inputs` that are of floating point or complex type
    and with ``requires_grad=True``.

    The check between numerical and analytical gradients uses :func:`~torch.allclose`.

    For most of the complex functions we consider for optimization purposes, no notion of
    Jacobian exists. Instead, gradcheck verifies if the numerical and analytical values of
    the Wirtinger and Conjugate Wirtinger derivatives are consistent. Because the gradient
    computation is done under the assumption that the overall function has a real-valued
    output, we treat functions with complex output in a special way. For these functions,
    gradcheck is applied to two real-valued functions corresponding to taking the real
    components of the complex outputs for the first, and taking the imaginary components
    of the complex outputs for the second. For more details, check out
    :ref:`complex_autograd-doc`.

    .. note::
        The default values are designed for :attr:`input` of double precision.
        This check will likely fail if :attr:`input` is of less precision, e.g.,
        ``FloatTensor``.

    .. note::
        Gradcheck may fail when evaluated on non-differentiable points
        because the numerically computed gradients via finite differencing may differ
        from those computed analytically (not necessarily because either is incorrect).
        For more context, see :ref:`non-differentiable-func-grad`.

    .. warning::
        If any checked tensor in :attr:`input` has overlapping memory, i.e.,
        different indices pointing to the same memory address (e.g., from
        :func:`torch.expand`), this check will likely fail because the numerical
        gradients computed by point perturbation at such indices will change
        values at all other indices that share the same memory address.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor or a tuple of Tensors
        inputs (tuple of Tensor or Tensor): inputs to the function
        eps (float, optional): perturbation for finite differences
        atol (float, optional): absolute tolerance
        rtol (float, optional): relative tolerance
        raise_exception (bool, optional): indicating whether to raise an exception if
            the check fails. The exception gives more information about the
            exact nature of the failure. This is helpful when debugging gradchecks.
        nondet_tol (float, optional): tolerance for non-determinism. When running
            identical inputs through the differentiation, the results must either match
            exactly (default, 0.0) or be within this tolerance.
        check_undefined_grad (bool, optional): if ``True``, check if undefined output grads
            are supported and treated as zeros, for ``Tensor`` outputs.
        check_batched_grad (bool, optional): if ``True``, check if we can compute
            batched gradients using prototype vmap support. Defaults to ``False``.
        check_batched_forward_grad (bool, optional): if ``True``, checks if we can compute
            batched forward gradients using forward AD and prototype vmap support. Defaults to ``False``.
        check_forward_ad (bool, optional): if ``True``, check that the gradients computed with forward
            mode AD match the numerical ones. Defaults to ``False``.
        check_backward_ad (bool, optional): if ``False``, do not perform any checks that rely on
            backward mode AD to be implemented. Defaults to ``True``.
        fast_mode (bool, optional): Fast mode for gradcheck and gradgradcheck is currently only
            implemented for R to R functions. If none of the inputs and outputs are complex,
            a faster implementation of gradcheck that no longer computes the entire jacobian
            is run; otherwise, we fall back to the slow implementation.
        masked (bool, optional): if ``True``, the gradients of unspecified elements of
            sparse tensors are ignored. Defaults to ``False``.

    Returns:
        ``True`` if all differences satisfy allclose condition

    """
    assert (
        check_forward_ad or check_backward_ad
    ), "Expected at least one of check_forward_ad or check_backward_ad to be True"
    assert not (
        check_batched_grad and not check_backward_ad
    ), "Setting check_batched_grad=True requires check_backward_ad to be True"
    assert not (
        check_batched_forward_grad and not check_forward_ad
    ), "Setting check_batched_forward_grad=True requires check_forward_ad to be True"
    args = locals().copy()
    args.pop("raise_exception")
    if not raise_exception:
        try:
            return _gradcheck_helper(**args)
        except GradcheckError:
            return False
    else:
        return _gradcheck_helper(**args)


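# Usage sketch (illustrative only; the wrapper below is hypothetical and not part of the
# public API): gradcheck is typically called with double-precision inputs that have
# requires_grad=True; it returns True on success and raises GradcheckError otherwise.
def _example_gradcheck_usage():
    x = torch.randn(4, dtype=torch.double, requires_grad=True)
    return gradcheck(torch.exp, (x,), eps=1e-6, atol=1e-4)

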
def _gradcheck_helper(
    func,
    inputs,
    eps,
    atol,
    rtol,
    nondet_tol,
    check_undefined_grad,
    check_grad_dtypes,
    check_batched_grad,
    check_batched_forward_grad,
    check_forward_ad,
    check_backward_ad,
    fast_mode,
    masked,
):
    tupled_inputs = _as_tuple(inputs)
    _check_inputs(tupled_inputs)

    func_out = func(*tupled_inputs)
    outputs = _differentiable_outputs(func_out)
    _check_outputs(outputs)

    gradcheck_fn = functools.partial(
        _fast_gradcheck if fast_mode else _slow_gradcheck, masked=masked
    )
    _gradcheck_real_imag(
        gradcheck_fn,
        func,
        func_out,
        tupled_inputs,
        outputs,
        eps,
        rtol,
        atol,
        check_grad_dtypes,
        check_forward_ad=check_forward_ad,
        check_backward_ad=check_backward_ad,
        nondet_tol=nondet_tol,
        check_undefined_grad=check_undefined_grad,
    )

    if check_batched_forward_grad:
        _test_batched_grad_forward_ad(func, tupled_inputs)

    # Short circuit because remaining tests rely on backward AD to be implemented
    if not check_backward_ad:
        return True

    for i, o in enumerate(outputs):
        if check_batched_grad:
            _test_batched_grad(tupled_inputs, o, i)

    _test_backward_mul_by_grad_output(outputs, tupled_inputs, masked)

    if check_undefined_grad and check_backward_ad:
        _test_undefined_backward_mode(func, outputs, tupled_inputs)
    return True


def gradgradcheck(
    func: Callable[..., _TensorOrTensors],  # See Note [VarArg of Tensors]
    inputs: _TensorOrTensors,
    grad_outputs: Optional[_TensorOrTensors] = None,
    *,
    eps: float = 1e-6,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    gen_non_contig_grad_outputs: bool = False,
    raise_exception: bool = True,
    nondet_tol: float = 0.0,
    check_undefined_grad: bool = True,
    check_grad_dtypes: bool = False,
    check_batched_grad: bool = False,
    check_fwd_over_rev: bool = False,
    check_rev_over_rev: bool = True,
    fast_mode: bool = False,
    masked: bool = False,
) -> bool:  # noqa: D400,D205
    r"""Check gradients of gradients computed via small finite differences
    against analytical gradients wrt tensors in :attr:`inputs` and
    :attr:`grad_outputs` that are of floating point or complex type and with
    ``requires_grad=True``.

    This function checks that backpropagating through the gradients computed
    to the given :attr:`grad_outputs` is correct.

    The check between numerical and analytical gradients uses :func:`~torch.allclose`.

    .. note::
        The default values are designed for :attr:`input` and
        :attr:`grad_outputs` of double precision. This check will likely fail if
        they are of less precision, e.g., ``FloatTensor``.

    .. warning::
        If any checked tensor in :attr:`input` and :attr:`grad_outputs` has
        overlapping memory, i.e., different indices pointing to the same memory
        address (e.g., from :func:`torch.expand`), this check will likely fail
        because the numerical gradients computed by point perturbation at such
        indices will change values at all other indices that share the same
        memory address.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor or a tuple of Tensors
        inputs (tuple of Tensor or Tensor): inputs to the function
        grad_outputs (tuple of Tensor or Tensor, optional): The gradients with
            respect to the function's outputs.
        eps (float, optional): perturbation for finite differences
        atol (float, optional): absolute tolerance
        rtol (float, optional): relative tolerance
        gen_non_contig_grad_outputs (bool, optional): if :attr:`grad_outputs` is
            ``None`` and :attr:`gen_non_contig_grad_outputs` is ``True``, the
            randomly generated gradient outputs are made to be noncontiguous
        raise_exception (bool, optional): indicating whether to raise an exception if
            the check fails. The exception gives more information about the
            exact nature of the failure. This is helpful when debugging gradchecks.
        nondet_tol (float, optional): tolerance for non-determinism. When running
            identical inputs through the differentiation, the results must either match
            exactly (default, 0.0) or be within this tolerance. Note that a small amount
            of nondeterminism in the gradient will lead to larger inaccuracies in
            the second derivative.
        check_undefined_grad (bool, optional): if ``True``, check if undefined output grads
            are supported and treated as zeros
        check_batched_grad (bool, optional): if ``True``, check if we can compute
            batched gradients using prototype vmap support. Defaults to ``False``.
        fast_mode (bool, optional): if ``True``, run a faster implementation of
            gradgradcheck that no longer computes the entire jacobian.
        masked (bool, optional): if ``True``, the gradients of unspecified elements of
            sparse tensors are ignored (default, ``False``).

    Returns:
        ``True`` if all differences satisfy allclose condition
    """
    assert (
        check_fwd_over_rev or check_rev_over_rev
    ), "Expected at least one of check_fwd_over_rev or check_rev_over_rev to be True"
    assert not (
        check_undefined_grad and not check_rev_over_rev
    ), "Setting check_undefined_grad=True requires check_rev_over_rev to be True"
    assert not (
        check_batched_grad and not check_rev_over_rev
    ), "Setting check_batched_grad=True requires check_rev_over_rev to be True"
    # TODO: do we want to test this too?
    # assert not (check_batched_forward_grad and not check_fwd_over_rev), (
    #     "Setting check_batched_forward_grad=True requires check_fwd_over_rev to be True")
    tupled_inputs = _as_tuple(inputs)

    if grad_outputs is None:
        # If grad_outputs is not specified, create random Tensors of the same shape, type, and device as the outputs

        outputs = _differentiable_outputs(func(*tupled_inputs))
        tupled_grad_outputs = tuple(
            torch.testing.make_tensor(
                x.shape,
                dtype=x.dtype
                if x.is_floating_point() or x.is_complex()
                else torch.double,
                device=x.device,
                low=-1,
                high=1,
                requires_grad=True,
                noncontiguous=gen_non_contig_grad_outputs,
            )
            for x in outputs
        )
    else:
        tupled_grad_outputs = _as_tuple(grad_outputs)

    num_outputs = len(tupled_grad_outputs)

    # NB: We need to save the requires_grad information about the inputs here because gradcheck detaches inputs
    # before running forward mode AD
    diff_input_args_indices = {
        i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad
    }
    diff_grad_output_indices = {
        i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad
    }

    def new_func(*args):
        # Restore the requires_grad information
        input_args = tuple(
            x.requires_grad_() if i in diff_input_args_indices else x
            for i, x in enumerate(args[:-num_outputs])
        )
        outputs = _differentiable_outputs(func(*input_args))
        grad_outputs = tuple(
            x.requires_grad_() if i in diff_grad_output_indices else x
            for i, x in enumerate(args[-num_outputs:])
        )
        diff_input_args = tuple(
            x for i, x in enumerate(input_args) if i in diff_input_args_indices
        )
        grad_inputs = torch.autograd.grad(
            outputs, diff_input_args, grad_outputs, create_graph=True, allow_unused=True
        )
        grad_inputs = tuple(g for g in grad_inputs if g is not None)
        return grad_inputs

    return gradcheck(
        new_func,
        tupled_inputs + tupled_grad_outputs,
        eps=eps,
        atol=atol,
        rtol=rtol,
        raise_exception=raise_exception,
        nondet_tol=nondet_tol,
        check_undefined_grad=check_undefined_grad,
        check_grad_dtypes=check_grad_dtypes,
        check_batched_grad=check_batched_grad,
        fast_mode=fast_mode,
        check_forward_ad=check_fwd_over_rev,
        check_backward_ad=check_rev_over_rev,
        masked=masked,
    )
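

# Usage sketch (illustrative only; the wrapper below is hypothetical): gradgradcheck
# verifies second derivatives, so the checked function should itself be twice
# differentiable with respect to its inputs.
def _example_gradgradcheck_usage():
    x = torch.randn(3, dtype=torch.double, requires_grad=True)
    return gradgradcheck(lambda t: (t**3).sum(), (x,))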