xref: /aosp_15_r20/external/pytorch/torch/cpu/amp/autocast_mode.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Any
3from typing_extensions import deprecated
4
5import torch
6
7
# Public names of this module; also limits what a star-import pulls in.
__all__ = ["autocast"]
9
10
class autocast(torch.amp.autocast_mode.autocast):
    r"""
    See :class:`torch.autocast`.
    ``torch.cpu.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cpu", args...)`` instead.

    This is a thin CPU-only wrapper around
    :class:`torch.amp.autocast_mode.autocast`: the device type passed to the
    parent is always ``"cpu"`` and the default ``dtype`` is ``torch.bfloat16``.
    Calling the constructor emits a ``FutureWarning`` through ``@deprecated``.

    Args:
        enabled (bool): Forwarded to the parent as ``enabled``. Default: ``True``.
        dtype (torch.dtype): Forwarded to the parent as ``dtype``.
            Default: ``torch.bfloat16``.
        cache_enabled (bool): Forwarded to the parent as ``cache_enabled``.
            Default: ``True``. Not used at all when running under TorchScript.

    Under TorchScript (``torch._jit_internal.is_scripting()`` returns true)
    every method bypasses the parent class: ``__init__`` only stores
    ``_enabled``, ``device`` and ``fast_dtype``; ``__enter__`` returns
    ``self``; ``__exit__`` returns ``None``; ``__call__`` returns the given
    function unchanged.
    """

    # Emits a FutureWarning on every construction so callers migrate to
    # torch.amp.autocast("cpu", ...).
    @deprecated(
        "`torch.cpu.amp.autocast(args...)` is deprecated. "
        "Please use `torch.amp.autocast('cpu', args...)` instead.",
        category=FutureWarning,
    )
    def __init__(
        self,
        enabled: bool = True,
        dtype: torch.dtype = torch.bfloat16,
        cache_enabled: bool = True,
    ):
        """Build a CPU autocast context (arguments are described on the class)."""
        if torch._jit_internal.is_scripting():
            # TorchScript path: the parent constructor is skipped and only these
            # three attributes are recorded (cache_enabled is dropped here).
            # NOTE(review): presumably the JIT's autocast support reads exactly
            # these attribute names -- confirm before renaming any of them.
            self._enabled = enabled
            self.device = "cpu"
            self.fast_dtype = dtype
            return
        # Eager path: delegate to the generic autocast with the device pinned to CPU.
        super().__init__(
            "cpu", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
        )

    def __enter__(self):
        """Enter the autocast region.

        In eager mode this returns whatever the parent's ``__enter__`` returns;
        under TorchScript it does nothing and returns ``self``.
        """
        if torch._jit_internal.is_scripting():
            return self
        return super().__enter__()

    # TODO: discuss a unified TorchScript-friendly API for autocast
    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        """Leave the autocast region.

        In eager mode the parent's ``__exit__`` result is passed through, so a
        truthy value from it suppresses an in-flight exception per the
        context-manager protocol. Under TorchScript this does nothing and
        returns ``None``.
        """
        if torch._jit_internal.is_scripting():
            return
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __call__(self, func):
        """Apply this autocast instance to ``func`` (decorator usage).

        In eager mode this delegates to the parent's ``__call__``; under
        TorchScript ``func`` is returned unchanged.
        """
        if torch._jit_internal.is_scripting():
            return func
        return super().__call__(func)
52