# mypy: ignore-errors

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import functools
import itertools
import os
import threading
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch._C._functorch import (
    _add_batch_dim,
    _remove_batch_dim,
    _vmap_decrement_nesting,
    _vmap_increment_nesting,
    is_batchedtensor,
)
from torch.utils._pytree import (
    _broadcast_to_and_flatten,
    tree_flatten,
    tree_map_,
    tree_unflatten,
    TreeSpec,
)


in_dims_t = Union[int, Tuple]
out_dims_t = Union[int, Tuple[int, ...]]


def doesnt_support_saved_tensors_hooks(f):
    message = (
        "torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. "
        "Please open an issue with your use case."
    )

    @functools.wraps(f)
    def fn(*args, **kwargs):
        with torch.autograd.graph.disable_saved_tensors_hooks(message):
            return f(*args, **kwargs)

    return fn
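
# Usage sketch (illustrative only; `my_transform` is a hypothetical function):
#
#   @doesnt_support_saved_tensors_hooks
#   def my_transform(f, *args):
#       return f(*args)
#
# While `my_transform` runs, saved-tensor hooks are disabled; attempting to use
# them is expected to raise a RuntimeError with the message above.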


# Checks that all args-to-be-batched have the same batch dim size
def _validate_and_get_batch_size(
    flat_in_dims: List[Optional[int]], flat_args: List
) -> int:
    batch_sizes = [
        arg.size(in_dim)
        for in_dim, arg in zip(flat_in_dims, flat_args)
        if in_dim is not None
    ]
    if len(batch_sizes) == 0:
        raise ValueError("vmap: Expected at least one Tensor to vmap over")
    if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes):
        raise ValueError(
            f"vmap: Expected all tensors to have the same size in the mapped "
            f"dimension, got sizes {batch_sizes} for the mapped dimension"
        )
    return batch_sizes[0]
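
# Illustrative sketch (hypothetical shapes, not executed here):
#
#   >>> xs, ys = torch.randn(3, 5), torch.randn(7, 3)
#   >>> _validate_and_get_batch_size([0, 1], [xs, ys])  # both mapped dims have size 3
#   3
#
# Mismatched sizes along the mapped dims (e.g. 3 vs 4) raise the ValueError above.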


def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
    if isinstance(batched_outputs, tuple):
        return len(batched_outputs)
    return 1


# If value is a tuple, check it has length `num_elements`.
# If value is not a tuple, make a tuple with `value` repeated `num_elements` times


def _as_tuple(
    value: Any, num_elements: int, error_message_lambda: Callable[[], str]
) -> Tuple:
    if not isinstance(value, tuple):
        return (value,) * num_elements
    if len(value) != num_elements:
        raise ValueError(error_message_lambda())
    return value
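
# Illustrative sketch (hypothetical values, not executed here):
#
#   >>> _as_tuple(0, 3, lambda: "unused")
#   (0, 0, 0)
#   >>> _as_tuple((0, 1, 2), 3, lambda: "unused")
#   (0, 1, 2)
#   >>> _as_tuple((0, 1), 3, lambda: "wrong length")  # raises ValueError("wrong length")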


def _process_batched_inputs(
    in_dims: in_dims_t, args: Tuple, func: Callable
) -> Tuple[int, List[Any], List[Any], TreeSpec]:
    if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
        raise ValueError(
            f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
            f"expected `in_dims` to be int or a (potentially nested) tuple "
            f"matching the structure of inputs, got: {type(in_dims)}."
        )
    if len(args) == 0:
        raise ValueError(
            f"vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add "
            f"inputs, or you are trying to vmap over a function with no inputs. "
            f"The latter is unsupported."
        )

    flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
            f"in_dims is not compatible with the structure of `inputs`. "
            f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs "
            f"has structure {args_spec}."
        )

    for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)):
        if not isinstance(in_dim, int) and in_dim is not None:
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for an input but in_dim must be either "
                f"an integer dimension or None."
            )
        if isinstance(in_dim, int) and not isinstance(arg, Tensor):
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for an input but the input is of type "
                f"{type(arg)}. We cannot vmap over non-Tensor arguments, "
                f"please use None as the respective in_dim"
            )
        if in_dim is not None and (in_dim < -arg.dim() or in_dim >= arg.dim()):
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for some input, but that input is a Tensor "
                f"of dimensionality {arg.dim()} so expected in_dim to satisfy "
                f"-{arg.dim()} <= in_dim < {arg.dim()}."
            )
        if in_dim is not None and in_dim < 0:
            flat_in_dims[i] = in_dim % arg.dim()

    return (
        _validate_and_get_batch_size(flat_in_dims, flat_args),
        flat_in_dims,
        flat_args,
        args_spec,
    )
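
# Illustrative sketch (hypothetical inputs, not executed here): negative in_dims
# are normalized and the common batch size is computed.
#
#   >>> x = torch.randn(2, 3)
#   >>> batch_size, flat_in_dims, flat_args, spec = _process_batched_inputs(
#   ...     in_dims=(-1, None), args=(x, 4.0), func=torch.sum
#   ... )
#   >>> batch_size, flat_in_dims
#   (3, [1, None])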


# Creates BatchedTensors for every Tensor in flat_args that should be batched.
# Returns the (potentially) batched arguments, unflattened into the original args structure.


def _create_batched_inputs(
    flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec
) -> Tuple:
    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    batched_inputs = [
        arg if in_dim is None else _add_batch_dim(arg, in_dim, vmap_level)
        for in_dim, arg in zip(flat_in_dims, flat_args)
    ]
    return tree_unflatten(batched_inputs, args_spec)


def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_dim):
    if out_dim is None:
        if isinstance(batched_output, torch.Tensor) and is_batchedtensor(
            batched_output
        ):
            raise ValueError(
                f"vmap({name}, ...): `{name}` can not return a "
                f"BatchedTensor when out_dim is None"
            )
        return batched_output

    # out_dim is not None
    if not isinstance(batched_output, torch.Tensor):
        raise ValueError(
            f"vmap({name}, ...): `{name}` must only return "
            f"Tensors, got type {type(batched_output)}. "
            "Did you mean to set out_dims=None for this output?"
        )

    return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)


# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
def _unwrap_batched(
    batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
    out_dims: out_dims_t,
    vmap_level: int,
    batch_size: int,
    func: Callable,
) -> Tuple:
    flat_batched_outputs, output_spec = tree_flatten(batched_outputs)

    def incompatible_error():
        raise ValueError(
            f"vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): "
            f"out_dims is not compatible with the structure of `outputs`. "
            f"out_dims has structure {tree_flatten(out_dims)[1]} but outputs "
            f"has structure {output_spec}."
        )

    if isinstance(batched_outputs, torch.Tensor):
        # Some weird edge case requires us to spell out the following
        # see test_out_dims_edge_case
        if isinstance(out_dims, int):
            flat_out_dims = [out_dims]
        elif isinstance(out_dims, tuple) and len(out_dims) == 1:
            flat_out_dims = out_dims
        elif out_dims is None:
            flat_out_dims = [out_dims]
        else:
            incompatible_error()
    else:
        flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec)
        if flat_out_dims is None:
            incompatible_error()

    flat_outputs = [
        _maybe_remove_batch_dim(
            _get_name(func), batched_output, vmap_level, batch_size, out_dim
        )
        for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims)
    ]
    return tree_unflatten(flat_outputs, output_spec)
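
# Illustrative sketch (hypothetical, not executed here): if `func` returns a dict
# of BatchedTensors and out_dims=1, each output gets its batch dimension
# materialized at dim 1, e.g.
#
#   >>> out = torch.vmap(lambda x: {"y": x * 2}, out_dims=1)(torch.randn(3, 5))
#   >>> out["y"].shape
#   torch.Size([5, 3])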


def _check_int_or_none(x, func, out_dims):
    if isinstance(x, int):
        return
    if x is None:
        return
    raise ValueError(
        f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be "
        f"an int, None or a python collection of ints representing where in the outputs the "
        f"vmapped dimension should appear."
    )


def _check_out_dims_is_int_or_int_pytree(out_dims: out_dims_t, func: Callable) -> None:
    if isinstance(out_dims, int):
        return
    tree_map_(partial(_check_int_or_none, func=func, out_dims=out_dims), out_dims)


def _get_name(func: Callable):
    if hasattr(func, "__name__"):
        return func.__name__

    # Not all callables have a __name__. A callable created via functools.partial
    # or an nn.Module, to name some examples, doesn't have a __name__.
    return repr(func)
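
# Illustrative sketch (not executed here):
#
#   >>> _get_name(torch.sin)
#   'sin'
#   >>> _get_name(functools.partial(torch.sum, dim=0))
#   # no __name__ on a partial object, so this falls back to repr(...)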


DECOMPOSITIONS_LOADED = False
DECOMPOSITIONS_LOCK = threading.Lock()
VMAP_DECOMPOSITIONS_LIB = None


# torch.package, Python 3.11, and torch.jit-less environments are unhappy with
# decompositions. Only load them when needed if possible.
def lazy_load_decompositions():
    global DECOMPOSITIONS_LOADED
    if DECOMPOSITIONS_LOADED:
        return

    with DECOMPOSITIONS_LOCK:
        if DECOMPOSITIONS_LOADED:
            return

        if not (os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__):
            DECOMPOSITIONS_LOADED = True
            return

        # Use an alternate way to register an operator into the decomposition table.
        # _register_jit_decomposition doesn't work for some operators (e.g. addr)
        # because the Tensor types generated cannot be unioned by torchscript.
        # `decomp` should be of type OpOverload.
        global VMAP_DECOMPOSITIONS_LIB
        VMAP_DECOMPOSITIONS_LIB = torch.library.Library(
            "aten", "IMPL", "FuncTorchBatched"
        )

        from torch._decomp import decomposition_table

        def _register_python_decomposition_vmap(decomp):
            if decomp in decomposition_table:
                VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp])
            else:
                raise RuntimeError(f"could not find decomposition for {decomp}")

        _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
        _register_python_decomposition_vmap(
            torch.ops.aten.smooth_l1_loss_backward.default
        )
        _register_python_decomposition_vmap(torch.ops.aten.huber_loss_backward.default)
        _register_python_decomposition_vmap(torch.ops.aten.nll_loss_forward.default)
        _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_forward.default)
        _register_python_decomposition_vmap(torch.ops.aten.nll_loss_backward.default)
        _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_backward.default)
        _register_python_decomposition_vmap(torch.ops.aten.addr.default)

        DECOMPOSITIONS_LOADED = True


def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs):
    lazy_load_decompositions()
    _check_out_dims_is_int_or_int_pytree(out_dims, func)
    batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
        in_dims, args, func
    )

    if chunk_size is not None:
        chunks_flat_args = _get_chunked_inputs(
            flat_args, flat_in_dims, batch_size, chunk_size
        )
        return _chunked_vmap(
            func,
            flat_in_dims,
            chunks_flat_args,
            args_spec,
            out_dims,
            randomness,
            **kwargs,
        )

    # If chunk_size is not specified.
    return _flat_vmap(
        func,
        batch_size,
        flat_in_dims,
        flat_args,
        args_spec,
        out_dims,
        randomness,
        **kwargs,
    )
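
# Illustrative sketch (hypothetical call, not executed here): the public
# torch.vmap / torch.func.vmap entry point forwards here. Passing a chunk_size
# makes the batch be processed in pieces:
#
#   >>> f = lambda x: x.sum(-1)
#   >>> x = torch.randn(1000, 16)
#   >>> torch.vmap(f, chunk_size=256)(x).shape  # 4 chunks: 256, 256, 256, 232
#   torch.Size([1000])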


def get_chunk_sizes(total_elems, chunk_size):
    n_chunks = total_elems // chunk_size
    chunk_sizes = [chunk_size] * n_chunks
    # remainder chunk
    remainder = total_elems % chunk_size
    if remainder != 0:
        chunk_sizes.append(remainder)
    return chunk_sizes
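
# Illustrative sketch (not executed here):
#
#   >>> get_chunk_sizes(10, 3)
#   [3, 3, 3, 1]
#   >>> get_chunk_sizes(6, 3)   # divides evenly, no remainder chunk
#   [3, 3]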


def _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size):
    split_idxs = (batch_size,)
    if chunk_size is not None:
        chunk_sizes = get_chunk_sizes(batch_size, chunk_size)
        split_idxs = tuple(itertools.accumulate(chunk_sizes))

    flat_args_chunks = tuple(
        t.tensor_split(split_idxs, dim=in_dim)
        if in_dim is not None
        else [
            t,
        ]
        * len(split_idxs)
        for t, in_dim in zip(flat_args, flat_in_dims)
    )

    # transpose the chunk dim and flatten the structure;
    # chunks_flat_args is an iterable of per-chunk flattened args
    chunks_flat_args = zip(*flat_args_chunks)
    return chunks_flat_args
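
# Illustrative sketch (hypothetical values, not executed here): with
# batch_size=5 and chunk_size=2, chunk_sizes == [2, 2, 1] and
# split_idxs == (2, 4, 5). Each batched arg is tensor_split along its in_dim
# into pieces of sizes 2, 2, 1 (plus a trailing size-0 piece), while args with
# in_dim=None are repeated once per split index; zip then regroups the pieces
# chunk by chunk, and empty chunks are skipped later in _chunked_vmap.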


def _flatten_chunks_output(chunks_output_):
    # chunks_output_ is a list of chunked outputs;
    # flatten each chunked output:
    flat_chunks_output = []
    arg_spec = None
    for output in chunks_output_:
        flat_output, arg_specs = tree_flatten(output)
        flat_chunks_output.append(flat_output)
        if arg_spec is None:
            arg_spec = arg_specs

    # transpose the chunk dim and flatten the structure;
    # flat_output_chunks is a flat list of chunks
    flat_output_chunks = list(zip(*flat_chunks_output))
    return flat_output_chunks, arg_spec
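
# Illustrative sketch (hypothetical values): if two chunks each produced a pair
# of outputs, [(a0, b0), (a1, b1)] is regrouped into [(a0, a1), (b0, b1)], so
# each entry can be concatenated along its out_dim by _concat_chunked_outputs.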


def _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks):
    # concat chunks on out_dim
    flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec)
    assert len(flat_out_dims) == len(flat_output_chunks)
    flat_output = []
    for idx, out_dim in enumerate(flat_out_dims):
        flat_output.append(torch.cat(flat_output_chunks[idx], dim=out_dim))
        # release tensors
        flat_output_chunks[idx] = None

    return flat_output


# Applies vmap on chunked_input and returns concatenated output over the chunks.
def _chunked_vmap(
    func, flat_in_dims, chunks_flat_args, args_spec, out_dims, randomness, **kwargs
):
    chunks_output = []
    rs = torch.get_rng_state() if randomness == "same" else None
    for flat_args in chunks_flat_args:
        batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)

        # Because of the way we split the input in `_get_chunked_inputs`,
        # we may get a chunk with batch-size `0`. We skip any computation
        # in that case.
        # Eg.
        # >>> chunk_size = 1
        # >>> batch_size = 6
        # >>> t = torch.zeros(batch_size, 1)
        # >>> t.tensor_split([1, 2, 3, 4, 5, 6])
        # (tensor([[0.]]), tensor([[0.]]), tensor([[0.]]), tensor([[0.]]),
        #  tensor([[0.]]), tensor([[0.]]), tensor([], size=(0, 1)))
        if batch_size == 0:
            continue

        if rs is not None:
            torch.set_rng_state(rs)
        chunks_output.append(
            _flat_vmap(
                func,
                batch_size,
                flat_in_dims,
                flat_args,
                args_spec,
                out_dims,
                randomness,
                **kwargs,
            )
        )

    flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output)

    # chunked output tensors are held by both `flat_output_chunks` and `chunks_output`.
    # eagerly remove the reference from `chunks_output`.
    del chunks_output

    # concat chunks on out_dim
    flat_output = _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks)

    # finally unflatten the output
    return tree_unflatten(flat_output, arg_spec)


# Vmap refactored helper functions:
def _check_randomness_arg(randomness):
    if randomness not in ["error", "different", "same"]:
        raise RuntimeError(
            f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}"
        )


@contextlib.contextmanager
def vmap_increment_nesting(batch_size, randomness):
    try:
        vmap_level = _vmap_increment_nesting(batch_size, randomness)
        yield vmap_level
    finally:
        _vmap_decrement_nesting()


def _flat_vmap(
    func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
):
    with vmap_increment_nesting(batch_size, randomness) as vmap_level:
        batched_inputs = _create_batched_inputs(
            flat_in_dims, flat_args, vmap_level, args_spec
        )
        batched_outputs = func(*batched_inputs, **kwargs)
        return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)


# `restore_vmap` is a private helper function. It is vmap but has the following
# differences:
# - instead of returning outputs, it returns an (outputs, out_dims) tuple.
#   out_dims is a pytree of same shape as outputs and contains Optional[int]
#   specifying where the vmapped dimension, if it exists, is in the corresponding output.
# - does no validation on in_dims or inputs (vmap expects at least one Tensor to be vmapped).
#   restore_vmap allows for no inputs to have the vmap dimension
# - does no validation on outputs (vmap expects only Tensor outputs)
#   restore_vmap allows for return of arbitrary outputs (not just Tensors)
#
# The TL;DR is that restore_vmap is more general than vmap and has a slightly
# different API. The relaxations are so that we can "pause" vmap in the middle
# of its execution and then "restore" it later (this is what we do in
# the generate_vmap_rule=True implementation of autograd.Function).
#
# restore_vmap could technically be used in the implementation of vmap, but doing
# that refactor is a bit challenging because:
# - vmap couples the tensor-wrapping code with error checking
# - vmap's tensor unwrapping code is in C++; we would need to rewrite part of it
#   in python because it overlaps with unwrap_batched
def restore_vmap(func, in_dims, batch_size, randomness):
    def inner(*args, **kwargs):
        with vmap_increment_nesting(batch_size, randomness) as vmap_level:
            batched_inputs = wrap_batched(args, in_dims, vmap_level)
            batched_outputs = func(*batched_inputs, **kwargs)
            return unwrap_batched(batched_outputs, vmap_level)

    return inner
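
# Illustrative sketch (hypothetical values, not executed here): unlike vmap,
# restore_vmap also reports where (or whether) each output carries a batch dim.
#
#   >>> f = lambda x, y: (x * 2, y)           # y is not batched
#   >>> out, out_dims = restore_vmap(f, (0, None), 3, "error")(torch.randn(3, 5), 7)
#   >>> out[0].shape, out_dims
#   (torch.Size([3, 5]), (0, None))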


def wrap_batched(args, bdims, level):
    flat_args, spec = tree_flatten(args)
    flat_bdims = _broadcast_to_and_flatten(bdims, spec)
    assert flat_bdims is not None
    result = _create_batched_inputs(flat_bdims, flat_args, level, spec)
    return result


def unwrap_batched(args, level):
    flat_args, spec = tree_flatten(args)
    if len(flat_args) == 0:
        return args, ()
    result = [
        torch._C._functorch._unwrap_batched(arg, level)
        if isinstance(arg, torch.Tensor)
        else (arg, None)
        for arg in flat_args
    ]
    output, bdims = zip(*result)
    return tree_unflatten(output, spec), tree_unflatten(bdims, spec)
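
# Note (descriptive, not executed): wrap_batched and unwrap_batched are
# approximate inverses at a given vmap level. wrap_batched turns plain tensors
# into BatchedTensors according to `bdims`; unwrap_batched peels that level off
# again, returning both the unwrapped pytree and a matching pytree of
# Optional[int] batch-dim positions (None for non-Tensor leaves and for tensors
# not batched at `level`).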
533