xref: /aosp_15_r20/external/pytorch/torch/testing/_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import contextlib
3
4import torch
5
6
7# Common testing utilities for use in public testing APIs.
8# NB: these should all be importable without optional dependencies
9# (like numpy and expecttest).
10
11
def wrapper_set_seed(op, *args, **kwargs):
    """Run ``op(*args, **kwargs)`` under a fixed manual seed without leaking RNG state.

    Intended for ops whose results depend on the RNG (for example dropout):
    the global RNG state is snapshotted, seeded with 42, the op is invoked,
    and the original RNG state is restored afterwards.

    See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.

    Returns:
        Whatever ``op`` returns.
    """
    with freeze_rng_state():
        torch.manual_seed(42)
        result = op(*args, **kwargs)

        produced_lazy_tensor = (
            isinstance(result, torch.Tensor) and result.device.type == "lazy"
        )
        if produced_lazy_tensor:
            # mark_step has to run while the RNG state is still frozen,
            # otherwise lazy numerics would not match eager execution.
            torch._lazy.mark_step()  # type: ignore[attr-defined]

    return result
26
27
28@contextlib.contextmanager
29def freeze_rng_state():
30    # no_dispatch needed for test_composite_compliance
31    # Some OpInfos use freeze_rng_state for rng determinism, but
32    # test_composite_compliance overrides dispatch for all torch functions
33    # which we need to disable to get and set rng state
34    with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
35        rng_state = torch.get_rng_state()
36        if torch.cuda.is_available():
37            cuda_rng_state = torch.cuda.get_rng_state()
38    try:
39        yield
40    finally:
41        # Modes are not happy with torch.cuda.set_rng_state
42        # because it clones the state (which could produce a Tensor Subclass)
43        # and then grabs the new tensor's data pointer in generator.set_state.
44        #
45        # In the long run torch.cuda.set_rng_state should probably be
46        # an operator.
47        #
48        # NB: Mode disable is to avoid running cross-ref tests on thes seeding
49        with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
50            if torch.cuda.is_available():
51                torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]
52            torch.set_rng_state(rng_state)
53