1# mypy: allow-untyped-defs 2from typing import Any 3from typing_extensions import deprecated 4 5import torch 6 7 8__all__ = ["autocast"] 9 10 11class autocast(torch.amp.autocast_mode.autocast): 12 r""" 13 See :class:`torch.autocast`. 14 ``torch.cpu.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cpu", args...)`` instead. 15 """ 16 17 @deprecated( 18 "`torch.cpu.amp.autocast(args...)` is deprecated. " 19 "Please use `torch.amp.autocast('cpu', args...)` instead.", 20 category=FutureWarning, 21 ) 22 def __init__( 23 self, 24 enabled: bool = True, 25 dtype: torch.dtype = torch.bfloat16, 26 cache_enabled: bool = True, 27 ): 28 if torch._jit_internal.is_scripting(): 29 self._enabled = enabled 30 self.device = "cpu" 31 self.fast_dtype = dtype 32 return 33 super().__init__( 34 "cpu", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled 35 ) 36 37 def __enter__(self): 38 if torch._jit_internal.is_scripting(): 39 return self 40 return super().__enter__() 41 42 # TODO: discuss a unified TorchScript-friendly API for autocast 43 def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] 44 if torch._jit_internal.is_scripting(): 45 return 46 return super().__exit__(exc_type, exc_val, exc_tb) 47 48 def __call__(self, func): 49 if torch._jit_internal.is_scripting(): 50 return func 51 return super().__call__(func) 52