# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import inspect
from typing import Callable, Dict, List, Optional, Tuple

import torch
import torch._decomp
from torch import Tensor
from torch._prims_common.wrappers import _maybe_remove_out_wrapper


decomposition_table = torch._decomp.decomposition_table
decomposition_table_for_jvp: Dict[torch._ops.OperatorBase, Callable] = {}
register_decomposition = torch._decomp.register_decomposition
aten = torch.ops.aten

# NOTE: [forward-mode AD decompositions mechanism]
#
# The mechanism is in VariableType,
#   IF any inputs have forward grad
#      AND there is no forward AD formula implemented
#      AND the function is actually differentiable
#   run the decomposition
#      See run_jit_decomposition_with_args_for_jvp
#      We currently use python decompositions that we torchscript.
#
# Note that we would be building the backward graph at the decomposed level
# too, but that is OK, because we would've errored out otherwise anyway.
#
# TODO: The mechanism we are using to register decompositions doesn't
# seem to be exclusively used for jvp. So an open question here is whether
# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things.
# If that is the case, we may go down the decomposition path unexpectedly
# (and possibly produce an unintelligible error) vs erroring out earlier and
# printing that the forward AD formula is not implemented.
#
# The solution to this may be to have an explicit allowlist that controls when
# the decomposition is enabled.


def maybe_register_decomposition(op):
    def decorator(f):
        try:
            return register_decomposition(op)(f)
        except Exception:
            return f

    return decorator


# Functions where we need a special decomposition for jvp but there's another version that
# should be used more generally (e.g. for jvp we need to recompute the mean and variance for
# the backward of a normalization function; without jvp, it should use the saved values)
decomposition_table_for_jvp = {}


def register_decomposition_for_jvp(fn):
    return register_decomposition(fn, registry=decomposition_table_for_jvp)


def _register_jit_decomposition_for_jvp(decomp, use_python=False):
    if decomp in decomposition_table_for_jvp:
        decomposition_table_used = decomposition_table_for_jvp
    elif decomp in decomposition_table:
        decomposition_table_used = decomposition_table
    else:
        raise RuntimeError(f"could not find decomposition for {decomp}")
    decomp_fn = decomposition_table_used[decomp]

    # `out_wrapper` extends a decomposition's signature with
    # an `out` parameter. However, jit will use the unwrapped function's
    # signature instead, so we need to unwrap here to prevent an error
    decomp_fn = _maybe_remove_out_wrapper(decomp_fn)

    if use_python:
        decomp_fn = torch.jit.ignore(decomp_fn)
        sig = inspect.signature(decomp_fn)

        # Create a string wrapping the function from the signature
        # example output:
        # def wrapped_decomp(x: torch.Tensor, y: int, z: int):
        #   return decomp_fn(x, y, z)
        # Thanks copilot!
        def get_function_def(sig):
            param_def = [f"{param_str}" for param_str in sig.parameters.values()]
            param_use = [f"{param_str}" for param_str in sig.parameters.keys()]

            return f"def wrapped_decomp({', '.join(param_def)}):\n  return decomp_fn({', '.join(param_use)})\n"

        f_str = get_function_def(sig)
        graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph
    else:
        graph = torch.jit.script(decomp_fn).graph
    torch.jit._register_decomposition(decomp, graph)

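# Rough usage sketch (illustrative only; it mirrors the real registrations at the
# bottom of this file): a jvp-specific decomposition is first registered into
# `decomposition_table_for_jvp` via the decorator above, then additionally
# scripted and handed to the JIT decomposition registry so VariableType can run it:
#
#   @register_decomposition_for_jvp(aten.native_layer_norm_backward)
#   def native_layer_norm_backward(...):
#       ...
#
#   _register_jit_decomposition_for_jvp(aten.native_layer_norm_backward.default)
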
# The only decompositions here are temporary or hacks for the purposes of jvp


# TODO: do these also belong here?
@maybe_register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:
    return torch.sum(torch.diag(self))


@maybe_register_decomposition(aten.log_sigmoid_forward.default)
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
    min = torch.minimum(self.new_zeros(()), self)
    z = torch.exp(-torch.abs(self))
    if self.is_cuda:
        buffer = self.new_zeros((0,))
    else:
        buffer = z
    return min - torch.log1p(z), buffer


def recompute_mean_var(
    input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool
):
    # For most norm decompositions, this is the only place they differ from the
    # core version: we recompute the mean and variance so that they track
    # gradients through input.

    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
    # rstd = 1 / sqrt(var + eps)  =>  eps = (1 / rstd) ** 2 - var
    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
    eps = eps.detach()
    rstd = 1 / torch.sqrt(var + eps)
    return mean, rstd


@register_decomposition_for_jvp(aten.native_layer_norm_backward)
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_ndim = input.dim()

    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices = list(range(axis, input_ndim))
    outer_dim_indices = list(range(0, axis))

    N = 1
    for i in inner_dims:
        N *= i
    M = 1
    for i in outer_dims:
        M *= i
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape),
            input.new_zeros(input_shape[axis:]),
            input.new_zeros(input_shape[axis:]),
        )

    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)

    x_hat = (input - mean_) * rstd_
    if weight is not None:
        grad_x_hat = grad_out * weight
    else:
        grad_x_hat = grad_out
    a = grad_x_hat * N
    b = torch.sum(grad_x_hat, inner_dim_indices, True)
    c1 = torch.mul(grad_x_hat, x_hat)
    c2 = torch.sum(c1, inner_dim_indices, True)
    c3 = torch.mul(x_hat, c2)
    inner = a - b - c3

    if output_mask[0]:
        d_input: Optional[Tensor] = (rstd_ / N) * inner
    else:
        d_input = torch.zeros_like(input)  # should be None but doesn't work with vjp

    if output_mask[1] and weight is not None:
        if len(outer_dim_indices) > 0:
            d_weight: Optional[Tensor] = torch.sum(
                grad_out * x_hat, outer_dim_indices, False
            )
        else:
            d_weight = grad_out * x_hat
    elif weight is not None:
        d_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
    else:
        d_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2] and bias is not None:
        if len(outer_dim_indices) > 0:
            d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
        else:
            d_bias = grad_out.clone()
    elif bias is not None:
        d_bias = torch.zeros_like(bias)  # should be None but doesn't work with vjp
    else:
        d_bias = torch.zeros(())  # should be None but doesn't work with vjp

    return (d_input, d_weight, d_bias)

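# For reference, the d_input branch above is the standard layer norm backward,
# with sums taken over the normalized (inner) dimensions:
#
#   d_input = rstd / N * (N * grad_x_hat
#                         - sum(grad_x_hat)
#                         - x_hat * sum(grad_x_hat * x_hat))
#
# The only jvp-specific difference from the core decomposition is that mean and
# rstd are recomputed (see recompute_mean_var) so they track gradients through
# `input`.
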
def prod(x: List[int]):
    r = 1
    for i in x:
        r *= i
    return r


@register_decomposition_for_jvp(aten.native_batch_norm_backward)
def native_batch_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    weight: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_invstd: Optional[Tensor],
    train: bool,
    eps: float,
    output_mask: List[bool],
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_rank = input.dim()
    assert input_rank >= 2, "rank of the input must be at least 2"

    axis = 1
    num_features = prod(input_shape) / input_shape[axis]  # type: ignore[arg-type]
    mean = save_mean
    invstd = save_invstd
    if train:
        assert (
            save_mean is not None and save_invstd is not None
        ), "when train=True, save_mean and save_invstd are required"

        reduction_dims = [0] + list(range(2, input.dim()))
        assert invstd is not None  # for typing
        mean, invstd = recompute_mean_var(input, invstd, reduction_dims, keepdim=False)
    else:
        assert running_mean is not None and running_var is not None
        mean = running_mean
        invstd = torch.rsqrt(running_var + eps)

    assert invstd is not None and mean is not None

    broadcast_mask = [1] * input_rank
    broadcast_mask[axis] = input_shape[axis]

    reduction_axes: List[int] = []
    for i in range(input_rank):
        if i != axis:
            reduction_axes.append(i)

    mean = torch.reshape(mean, broadcast_mask)
    norm = 1.0 / num_features
    grad_output_sum = torch.sum(grad_out, reduction_axes)
    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)

    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)

    if weight is None:
        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
    else:
        grad_scale = torch.reshape(invstd * weight, broadcast_mask)

    if train:
        proj = (input - mean) * proj_scale
        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
    else:
        grad_input = grad_out * grad_scale

    if output_mask[1]:
        grad_weight = dot_p * invstd
    elif weight is not None:
        grad_weight = torch.zeros_like(
            weight
        )  # should be None but doesn't work with vjp
    else:
        grad_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2]:
        grad_bias = grad_output_sum
    else:
        grad_bias = torch.zeros_like(
            grad_output_sum
        )  # should be None but doesn't work with vjp

    return (grad_input, grad_weight, grad_bias)

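# For reference, the train-mode grad_input above is the standard batch norm
# backward, written with means over all non-channel dimensions
# (norm = 1 / num_features):
#
#   grad_input = (grad_out
#                 - mean(grad_out)
#                 - x_hat * mean(grad_out * x_hat)) * invstd * weight
#
# (drop the `* weight` factor when weight is None). As with layer norm, the
# jvp-specific piece is recomputing mean/invstd (see recompute_mean_var) so
# they track gradients through `input`.
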
@register_decomposition_for_jvp(aten.batch_norm_backward)
def batch_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    weight: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_var: Optional[Tensor],
    update: bool,
    eps: float,
    output_mask: List[bool],
    reserve: Tensor,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    return native_batch_norm_backward(
        grad_out,
        input,
        weight,
        running_mean,
        running_var,
        save_mean,
        save_var,
        update,
        eps,
        output_mask,
    )


_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten._log_softmax_backward_data.default)
_register_jit_decomposition_for_jvp(torch.ops.aten._softmax_backward_data.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.miopen_batch_norm_backward.default)
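# Note: _register_jit_decomposition_for_jvp raises a RuntimeError for any op
# above that has no Python decomposition registered (in either
# decomposition_table_for_jvp or the general decomposition_table), so a new
# entry's decomposition must be defined before it is listed here.
# `use_python=True` (only used for aten.trace.default above) marks the Python
# decomposition as torch.jit.ignore'd and scripts a thin wrapper generated from
# its signature instead of scripting the function directly.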