# mypy: allow-untyped-defs
import contextlib

import torch


# Common testing utilities for use in public testing APIs.
# NB: these should all be importable without optional dependencies
# (like numpy and expecttest).


def wrapper_set_seed(op, *args, **kwargs):
    """Wrapper to set the seed manually for some functions, like dropout.

    See https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
    """
    with freeze_rng_state():
        torch.manual_seed(42)
        output = op(*args, **kwargs)

        if isinstance(output, torch.Tensor) and output.device.type == "lazy":
            # We need to call mark_step inside freeze_rng_state so that numerics
            # match eager execution.
            torch._lazy.mark_step()  # type: ignore[attr-defined]

        return output


@contextlib.contextmanager
def freeze_rng_state():
    # no_dispatch is needed for test_composite_compliance.
    # Some OpInfos use freeze_rng_state for RNG determinism, but
    # test_composite_compliance overrides dispatch for all torch functions,
    # which we need to disable in order to get and set the RNG state.
    with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
        rng_state = torch.get_rng_state()
        if torch.cuda.is_available():
            cuda_rng_state = torch.cuda.get_rng_state()
    try:
        yield
    finally:
        # Modes are not happy with torch.cuda.set_rng_state
        # because it clones the state (which could produce a Tensor Subclass)
        # and then grabs the new tensor's data pointer in generator.set_state.
        #
        # In the long run torch.cuda.set_rng_state should probably be
        # an operator.
        #
        # NB: Mode disable is to avoid running cross-ref tests on this seeding.
        with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
            if torch.cuda.is_available():
                torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]
            torch.set_rng_state(rng_state)
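

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this module's API): because
# wrapper_set_seed reseeds to 42 inside freeze_rng_state, a nondeterministic
# op such as dropout produces identical output on every call. The demo below
# is a hypothetical addition, guarded so it never runs on import.
if __name__ == "__main__":
    x = torch.ones(8)
    out1 = wrapper_set_seed(torch.nn.functional.dropout, x, p=0.5)
    out2 = wrapper_set_seed(torch.nn.functional.dropout, x, p=0.5)
    # Same seed on each call -> identical dropout masks.
    assert torch.equal(out1, out2)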
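
    # Sketch (hypothetical, continuing the guarded demo above): on exit,
    # freeze_rng_state restores the global (and, when available, CUDA) RNG
    # state, so consuming random numbers inside the context does not perturb
    # the caller's RNG stream.
    before = torch.get_rng_state()
    with freeze_rng_state():
        torch.manual_seed(0)
        torch.randn(3)  # consumes RNG inside the frozen scope
    assert torch.equal(before, torch.get_rng_state())  # state restored on exit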