# mypy: ignore-errors

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import functools
import itertools
import os
import threading
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch._C._functorch import (
    _add_batch_dim,
    _remove_batch_dim,
    _vmap_decrement_nesting,
    _vmap_increment_nesting,
    is_batchedtensor,
)
from torch.utils._pytree import (
    _broadcast_to_and_flatten,
    tree_flatten,
    tree_map_,
    tree_unflatten,
    TreeSpec,
)


in_dims_t = Union[int, Tuple]
out_dims_t = Union[int, Tuple[int, ...]]


def doesnt_support_saved_tensors_hooks(f):
    message = (
        "torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. "
        "Please open an issue with your use case."
    )

    @functools.wraps(f)
    def fn(*args, **kwargs):
        with torch.autograd.graph.disable_saved_tensors_hooks(message):
            return f(*args, **kwargs)

    return fn


# Checks that all args-to-be-batched have the same batch dim size
def _validate_and_get_batch_size(
    flat_in_dims: List[Optional[int]], flat_args: List
) -> int:
    batch_sizes = [
        arg.size(in_dim)
        for in_dim, arg in zip(flat_in_dims, flat_args)
        if in_dim is not None
    ]
    if len(batch_sizes) == 0:
        raise ValueError("vmap: Expected at least one Tensor to vmap over")
    if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes):
        raise ValueError(
            f"vmap: Expected all tensors to have the same size in the mapped "
            f"dimension, got sizes {batch_sizes} for the mapped dimension"
        )
    return batch_sizes[0]


def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
    if isinstance(batched_outputs, tuple):
        return len(batched_outputs)
    return 1


# If value is a tuple, check it has length `num_elements`.
# If value is not a tuple, make a tuple with `value` repeated `num_elements` times.


def _as_tuple(
    value: Any, num_elements: int, error_message_lambda: Callable[[], str]
) -> Tuple:
    if not isinstance(value, tuple):
        return (value,) * num_elements
    if len(value) != num_elements:
        raise ValueError(error_message_lambda())
    return value


def _process_batched_inputs(
    in_dims: in_dims_t, args: Tuple, func: Callable
) -> Tuple[int, List[Any], List[Any], TreeSpec]:
    if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
        raise ValueError(
            f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
            f"expected `in_dims` to be int or a (potentially nested) tuple "
            f"matching the structure of inputs, got: {type(in_dims)}."
        )
    if len(args) == 0:
        raise ValueError(
            f"vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add "
            f"inputs, or you are trying to vmap over a function with no inputs. "
            f"The latter is unsupported."
        )

    flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
            f"in_dims is not compatible with the structure of `inputs`. "
            f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs "
            f"has structure {args_spec}."
        )

    for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)):
        if not isinstance(in_dim, int) and in_dim is not None:
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for an input but in_dim must be either "
                f"an integer dimension or None."
            )
        if isinstance(in_dim, int) and not isinstance(arg, Tensor):
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for an input but the input is of type "
                f"{type(arg)}. We cannot vmap over non-Tensor arguments, "
                f"please use None as the respective in_dim"
            )
        if in_dim is not None and (in_dim < -arg.dim() or in_dim >= arg.dim()):
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for some input, but that input is a Tensor "
                f"of dimensionality {arg.dim()} so expected in_dim to satisfy "
                f"-{arg.dim()} <= in_dim < {arg.dim()}."
            )
        if in_dim is not None and in_dim < 0:
            flat_in_dims[i] = in_dim % arg.dim()

    return (
        _validate_and_get_batch_size(flat_in_dims, flat_args),
        flat_in_dims,
        flat_args,
        args_spec,
    )
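

# For example (illustrative; `torch.add` is just an arbitrary stand-in func):
# >>> args = (torch.randn(5, 3), torch.randn(5))
# >>> batch_size, flat_in_dims, flat_args, spec = _process_batched_inputs(
# ...     (0, 0), args, torch.add
# ... )
# >>> batch_size, flat_in_dims
# (5, [0, 0])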


# Creates BatchedTensors for every Tensor in args that should be batched.
# Returns the (potentially) batched arguments and the batch_size.


def _create_batched_inputs(
    flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec
) -> Tuple:
    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    batched_inputs = [
        arg if in_dim is None else _add_batch_dim(arg, in_dim, vmap_level)
        for in_dim, arg in zip(flat_in_dims, flat_args)
    ]
    return tree_unflatten(batched_inputs, args_spec)


def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_dim):
    if out_dim is None:
        if isinstance(batched_output, torch.Tensor) and is_batchedtensor(
            batched_output
        ):
            raise ValueError(
                f"vmap({name}, ...): `{name}` can not return a "
                f"BatchedTensor when out_dim is None"
            )
        return batched_output

    # out_dim is not None
    if not isinstance(batched_output, torch.Tensor):
        raise ValueError(
            f"vmap({name}, ...): `{name}` must only return "
            f"Tensors, got type {type(batched_output)}. "
            "Did you mean to set out_dims=None for this output?"
        )

    return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)


# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
def _unwrap_batched(
    batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
    out_dims: out_dims_t,
    vmap_level: int,
    batch_size: int,
    func: Callable,
) -> Tuple:
    flat_batched_outputs, output_spec = tree_flatten(batched_outputs)

    def incompatible_error():
        raise ValueError(
            f"vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): "
            f"out_dims is not compatible with the structure of `outputs`. "
            f"out_dims has structure {tree_flatten(out_dims)[1]} but outputs "
            f"has structure {output_spec}."
        )

    if isinstance(batched_outputs, torch.Tensor):
        # Some weird edge case requires us to spell out the following;
        # see test_out_dims_edge_case
        if isinstance(out_dims, int):
            flat_out_dims = [out_dims]
        elif isinstance(out_dims, tuple) and len(out_dims) == 1:
            flat_out_dims = out_dims
        elif out_dims is None:
            flat_out_dims = [out_dims]
        else:
            incompatible_error()
    else:
        flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec)
        if flat_out_dims is None:
            incompatible_error()

    flat_outputs = [
        _maybe_remove_batch_dim(
            _get_name(func), batched_output, vmap_level, batch_size, out_dim
        )
        for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims)
    ]
    return tree_unflatten(flat_outputs, output_spec)
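

# For example (illustrative), out_dims controls where the vmapped dimension is
# placed in each output when the batch dim is removed at this level:
# >>> xs = torch.randn(5, 3)
# >>> torch.vmap(torch.sin, out_dims=1)(xs).shape
# torch.Size([3, 5])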


def _check_int_or_none(x, func, out_dims):
    if isinstance(x, int):
        return
    if x is None:
        return
    raise ValueError(
        f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be "
        f"an int, None or a python collection of ints representing where in the outputs the "
        f"vmapped dimension should appear."
    )


def _check_out_dims_is_int_or_int_pytree(out_dims: out_dims_t, func: Callable) -> None:
    if isinstance(out_dims, int):
        return
    tree_map_(partial(_check_int_or_none, func=func, out_dims=out_dims), out_dims)


def _get_name(func: Callable):
    if hasattr(func, "__name__"):
        return func.__name__

    # Not all callables have __name__; in fact, only static functions/methods do.
    # A callable created via functools.partial or an nn.Module, to name some
    # examples, doesn't have a __name__.
    return repr(func)


DECOMPOSITIONS_LOADED = False
DECOMPOSITIONS_LOCK = threading.Lock()
VMAP_DECOMPOSITIONS_LIB = None


# torch.package, Python 3.11, and torch.jit-less environments are unhappy with
# decompositions. Only load them when needed if possible.
def lazy_load_decompositions():
    global DECOMPOSITIONS_LOADED
    if DECOMPOSITIONS_LOADED:
        return

    with DECOMPOSITIONS_LOCK:
        if DECOMPOSITIONS_LOADED:
            return

        if not (os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__):
            DECOMPOSITIONS_LOADED = True
            return

        # use an alternate way to register an operator into the decomposition table
        # _register_jit_decomposition doesn't work for some operators, e.g. addr,
        # because the Tensor types generated cannot be unioned by torchscript
        # decomp should be type OpOverload
        global VMAP_DECOMPOSITIONS_LIB
        VMAP_DECOMPOSITIONS_LIB = torch.library.Library(
            "aten", "IMPL", "FuncTorchBatched"
        )

        from torch._decomp import decomposition_table

        def _register_python_decomposition_vmap(decomp):
            if decomp in decomposition_table:
                VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp])
            else:
                raise RuntimeError(f"could not find decomposition for {decomp}")

        _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
        _register_python_decomposition_vmap(
            torch.ops.aten.smooth_l1_loss_backward.default
        )
        _register_python_decomposition_vmap(torch.ops.aten.huber_loss_backward.default)
        _register_python_decomposition_vmap(torch.ops.aten.nll_loss_forward.default)
        _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_forward.default)
        _register_python_decomposition_vmap(torch.ops.aten.nll_loss_backward.default)
        _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_backward.default)
        _register_python_decomposition_vmap(torch.ops.aten.addr.default)

        DECOMPOSITIONS_LOADED = True


def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs):
    lazy_load_decompositions()
    _check_out_dims_is_int_or_int_pytree(out_dims, func)
    batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
        in_dims, args, func
    )

    if chunk_size is not None:
        chunks_flat_args = _get_chunked_inputs(
            flat_args, flat_in_dims, batch_size, chunk_size
        )
        return _chunked_vmap(
            func,
            flat_in_dims,
            chunks_flat_args,
            args_spec,
            out_dims,
            randomness,
            **kwargs,
        )

    # If chunk_size is not specified.
    return _flat_vmap(
        func,
        batch_size,
        flat_in_dims,
        flat_args,
        args_spec,
        out_dims,
        randomness,
        **kwargs,
    )
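

# For example (illustrative): with batch_size=5 and chunk_size=2 the inputs are
# split into chunks of sizes [2, 2, 1] (computed by `get_chunk_sizes` below),
# each chunk is vmapped by a separate `_flat_vmap` call, and the per-chunk
# outputs are concatenated along `out_dims`:
# >>> xs = torch.randn(5, 3)
# >>> torch.vmap(torch.sum, chunk_size=2)(xs).shape
# torch.Size([5])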


def get_chunk_sizes(total_elems, chunk_size):
    n_chunks = total_elems // chunk_size
    chunk_sizes = [chunk_size] * n_chunks
    # remainder chunk
    remainder = total_elems % chunk_size
    if remainder != 0:
        chunk_sizes.append(remainder)
    return chunk_sizes


def _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size):
    split_idxs = (batch_size,)
    if chunk_size is not None:
        chunk_sizes = get_chunk_sizes(batch_size, chunk_size)
        split_idxs = tuple(itertools.accumulate(chunk_sizes))

    flat_args_chunks = tuple(
        t.tensor_split(split_idxs, dim=in_dim)
        if in_dim is not None
        else [
            t,
        ]
        * len(split_idxs)
        for t, in_dim in zip(flat_args, flat_in_dims)
    )

    # transpose chunk dim and flatten structure
    # chunks_flat_args is a list of flattened args
    chunks_flat_args = zip(*flat_args_chunks)
    return chunks_flat_args


def _flatten_chunks_output(chunks_output_):
    # chunks_output_ is a list of chunked outputs;
    # flatten the chunked outputs:
    flat_chunks_output = []
    arg_spec = None
    for output in chunks_output_:
        flat_output, arg_specs = tree_flatten(output)
        flat_chunks_output.append(flat_output)
        if arg_spec is None:
            arg_spec = arg_specs

    # transpose chunk dim and flatten structure
    # flat_output_chunks is a flat list of chunks
    flat_output_chunks = list(zip(*flat_chunks_output))
    return flat_output_chunks, arg_spec


def _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks):
    # concat chunks on out_dim
    flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec)
    assert len(flat_out_dims) == len(flat_output_chunks)
    flat_output = []
    for idx, out_dim in enumerate(flat_out_dims):
        flat_output.append(torch.cat(flat_output_chunks[idx], dim=out_dim))
        # release tensors
        flat_output_chunks[idx] = None

    return flat_output


# Applies vmap on the chunked inputs and returns the concatenated output over the chunks.
def _chunked_vmap(
    func, flat_in_dims, chunks_flat_args, args_spec, out_dims, randomness, **kwargs
):
    chunks_output = []
    rs = torch.get_rng_state() if randomness == "same" else None
    for flat_args in chunks_flat_args:
        batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)

        # Because of the way we split the inputs in `_get_chunked_inputs`,
        # we may get a chunk with batch size `0`. We skip any computation
        # in that case.
        # Eg.
        # >>> chunk_size = 1
        # >>> batch_size = 6
        # >>> t = torch.zeros(batch_size, 1)
        # >>> t.tensor_split([1, 2, 3, 4, 5, 6])
        # (tensor([[0.]]), tensor([[0.]]), tensor([[0.]]), tensor([[0.]]),
        #  tensor([[0.]]), tensor([[0.]]), tensor([], size=(0, 1)))
        if batch_size == 0:
            continue

        if rs is not None:
            torch.set_rng_state(rs)
        chunks_output.append(
            _flat_vmap(
                func,
                batch_size,
                flat_in_dims,
                flat_args,
                args_spec,
                out_dims,
                randomness,
                **kwargs,
            )
        )

    flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output)

    # Chunked output tensors are held by both `flat_output_chunks` and `chunks_output`;
    # eagerly remove the reference from `chunks_output`.
    del chunks_output

    # concat chunks on out_dim
    flat_output = _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks)

    # finally unflatten the output
    return tree_unflatten(flat_output, arg_spec)
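

# For example (illustrative), restoring the RNG state before each chunk keeps
# randomness="same" consistent with the unchunked result:
# >>> torch.manual_seed(0)
# >>> a = torch.vmap(lambda x: x + torch.rand(()), randomness="same")(torch.zeros(4))
# >>> torch.manual_seed(0)
# >>> b = torch.vmap(
# ...     lambda x: x + torch.rand(()), randomness="same", chunk_size=2
# ... )(torch.zeros(4))
# >>> torch.equal(a, b)
# True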


# Vmap refactored helper functions:
def _check_randomness_arg(randomness):
    if randomness not in ["error", "different", "same"]:
        raise RuntimeError(
            f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}"
        )


@contextlib.contextmanager
def vmap_increment_nesting(batch_size, randomness):
    try:
        vmap_level = _vmap_increment_nesting(batch_size, randomness)
        yield vmap_level
    finally:
        _vmap_decrement_nesting()


def _flat_vmap(
    func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
):
    with vmap_increment_nesting(batch_size, randomness) as vmap_level:
        batched_inputs = _create_batched_inputs(
            flat_in_dims, flat_args, vmap_level, args_spec
        )
        batched_outputs = func(*batched_inputs, **kwargs)
        return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)


# `restore_vmap` is a private helper function. It is vmap but has the following
# differences:
# - instead of returning outputs, it returns an (outputs, out_dims) tuple.
#   out_dims is a pytree of the same shape as outputs and contains Optional[int]
#   specifying where the vmapped dimension, if it exists, is in the corresponding output.
# - does no validation on in_dims or inputs (vmap expects at least one Tensor to be vmapped);
#   restore_vmap allows for no inputs to have the vmap dimension.
# - does no validation on outputs (vmap expects only Tensor outputs);
#   restore_vmap allows for the return of arbitrary outputs (not just Tensors).
#
# The TL;DR is that restore_vmap is more general than vmap and has a slightly
# different API. The relaxations are so that we can "pause" vmap in the middle
# of its execution and then "restore" it later (this is what we do in
# the generate_vmap_rule=True implementation of autograd.Function).
#
# restore_vmap can technically be used in the implementation of vmap, but doing
# that refactor is a bit challenging because:
# - vmap couples the tensor-wrapping code with error checking
# - vmap's tensor unwrapping code is in C++; we would need to rewrite part of it
#   in python because it overlaps with unwrap_batched
def restore_vmap(func, in_dims, batch_size, randomness):
    def inner(*args, **kwargs):
        with vmap_increment_nesting(batch_size, randomness) as vmap_level:
            batched_inputs = wrap_batched(args, in_dims, vmap_level)
            batched_outputs = func(*batched_inputs, **kwargs)
            return unwrap_batched(batched_outputs, vmap_level)

    return inner


def wrap_batched(args, bdims, level):
    flat_args, spec = tree_flatten(args)
    flat_bdims = _broadcast_to_and_flatten(bdims, spec)
    assert flat_bdims is not None
    result = _create_batched_inputs(flat_bdims, flat_args, level, spec)
    return result


def unwrap_batched(args, level):
    flat_args, spec = tree_flatten(args)
    if len(flat_args) == 0:
        return args, ()
    result = [
        torch._C._functorch._unwrap_batched(arg, level)
        if isinstance(arg, torch.Tensor)
        else (arg, None)
        for arg in flat_args
    ]
    output, bdims = zip(*result)
    return tree_unflatten(output, spec), tree_unflatten(bdims, spec)
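

# For example (illustrative): unlike vmap, restore_vmap does not move the batch
# dimension to `out_dims`; it returns the outputs together with where the batch
# dimension ended up (or None if an output was not batched):
# >>> x = torch.randn(3, 4)
# >>> out, out_bdims = restore_vmap(torch.sin, (0,), 3, "error")(x)
# >>> out.shape, out_bdims
# (torch.Size([3, 4]), 0)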