# Owner(s): ["module: unknown"] import unittest import torch from torch.testing._internal.autocast_test_lists import ( AutocastCPUTestLists, TestAutocast, ) from torch.testing._internal.common_utils import ( IS_WINDOWS, run_tests, skipIfTorchDynamo, TestCase, ) from torch.utils._python_dispatch import TorchDispatchMode class TestAutocastCPU(TestAutocast): def setUp(self): super().setUp() self.autocast_lists = AutocastCPUTestLists(torch.device("cpu")) def tearDown(self): del self.autocast_lists super().tearDown() @skipIfTorchDynamo() def test_autocast_torch_expect_builtin_promote(self): for ( op, args1, args2, out_type, ) in self.autocast_lists.torch_expect_builtin_promote: self._run_autocast_outofplace( op, args1, torch.float32, device="cpu", out_type=out_type ) self._run_autocast_outofplace( op, args2, torch.float32, device="cpu", out_type=out_type, amp_dtype=torch.float16, ) @skipIfTorchDynamo() def test_autocast_methods_expect_builtin_promote(self): for ( op, args1, args2, out_type, ) in self.autocast_lists.methods_expect_builtin_promote: self._run_autocast_outofplace( op, args1, torch.float32, device="cpu", module=None, out_type=out_type ) self._run_autocast_outofplace( op, args2, torch.float32, device="cpu", module=None, out_type=out_type, amp_dtype=torch.float16, ) @skipIfTorchDynamo() def test_autocast_torch_16(self): for op_with_args in self.autocast_lists.torch_16: op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) self._run_autocast_outofplace( op, args, torch.bfloat16, device="cpu", add_kwargs=maybe_kwargs ) self._run_autocast_outofplace( op, args, torch.float16, device="cpu", add_kwargs=maybe_kwargs, amp_dtype=torch.float16, ) @skipIfTorchDynamo() def test_autocast_nn_16(self): for op_with_args in self.autocast_lists.nn_16: op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) self._run_autocast_outofplace( op, args, torch.bfloat16, device="cpu", module=torch._C._nn, add_kwargs=maybe_kwargs, ) self._run_autocast_outofplace( op, args, torch.float16, device="cpu", module=torch._C._nn, add_kwargs=maybe_kwargs, amp_dtype=torch.float16, ) @skipIfTorchDynamo() def test_autocast_torch_fp32(self): for op_with_args in self.autocast_lists.torch_fp32: op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) self._run_autocast_outofplace( op, args, torch.float32, device="cpu", add_kwargs=maybe_kwargs ) self._run_autocast_outofplace( op, args, torch.float32, device="cpu", add_kwargs=maybe_kwargs, amp_dtype=torch.float16, ) @skipIfTorchDynamo() def test_autocast_nn_fp32(self): for op_with_args in self.autocast_lists.nn_fp32: op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) self._run_autocast_outofplace( op, args, torch.float32, device="cpu", module=torch._C._nn, add_kwargs=maybe_kwargs, ) self._run_autocast_outofplace( op, args, torch.float32, device="cpu", module=torch._C._nn, add_kwargs=maybe_kwargs, amp_dtype=torch.float16, ) @skipIfTorchDynamo() def test_autocast_torch_need_autocast_promote(self): for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote: self._run_autocast_outofplace(op, args1, torch.float32, device="cpu") self._run_autocast_outofplace( op, args2, torch.float32, device="cpu", amp_dtype=torch.float16 ) @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path") def test_autocast_rnn(self): if ( torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_bf16_supported() ): x = torch.randn(1, 2, 1) hx = torch.randn(2, 2, 1) cx = torch.randn(2, 2, 1) m = torch.nn.LSTM(1, 1, 2).to(torch.bfloat16) # Raise ValueError when autocast is not enabled with self.assertRaisesRegex(ValueError, "input must have the type"): m(x, (hx, cx)) # Should be able to run the below case with autocast with torch.amp.autocast(device_type="cpu"): m(x, (hx, cx)) def test_autocast_disabled_with_fp32_dtype(self): with torch.autocast(device_type="cpu", dtype=torch.float32, enabled=False): _ = torch.ones(10) def test_generic_autocast(self): for op_with_args in self.autocast_lists.torch_16: op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) with torch.amp.autocast(device_type="cpu"): generic_autocast_output = getattr(torch, op)(*args, **maybe_kwargs) with torch.amp.autocast(device_type="cpu"): cpu_autocast_output = getattr(torch, op)(*args, **maybe_kwargs) self.assertEqual(generic_autocast_output, cpu_autocast_output) def test_cpu_autocast_deprecated_warning(self): with self.assertWarnsRegex( FutureWarning, r"`torch.cpu.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cpu', args...\)` instead.", ): with torch.cpu.amp.autocast(): _ = torch.ones(10) class CustomLinear(torch.autograd.Function): @staticmethod def forward(ctx, x, w_t): ctx.save_for_backward(x, w_t) return torch.nn.functional.linear(x, w_t) @staticmethod def backward(ctx, grad_output): x, w_t = ctx.saved_tensors with torch.autocast(device_type="cuda"): dL_dX = torch.matmul(grad_output, w_t) dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1) return dL_dX, dL_dW class WeightDTypeCastCounterMode(TorchDispatchMode): def __init__(self, weight): super().__init__() self.dtype_cast_counter = 0 self.weight = weight def __torch_dispatch__(self, func, types, args=(), kwargs=None): if ( func is torch.ops.aten._to_copy.default and args[0] is self.weight and kwargs["dtype"] is torch.float16 ): self.dtype_cast_counter += 1 return func(*args, **kwargs) def __enter__(self): self.old_clear_cache = torch.clear_autocast_cache torch.clear_autocast_cache = lambda: None return super().__enter__() def __exit__(self, exc_type, exc_val, exc_tb): torch.clear_autocast_cache = self.old_clear_cache return super().__exit__(exc_type, exc_val, exc_tb) @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") class TestAutocastGPU(TestCase): def test_cast_cache_is_global(self): """ Verifies that the autocast cache is global. This is done by mocking out cache clearing at the end of the forward pass, running forward+backward with an explicit call to autocast in the backward, and verifying that the weight only get cast to float16 once. """ data = torch.randn(2, 3).cuda() weight = torch.nn.Parameter(torch.randn(4, 3).cuda()) with WeightDTypeCastCounterMode(weight) as mode: with torch.autocast(device_type="cuda"): output = CustomLinear.apply(data, weight) s = output.sum() s.backward() self.assertEqual(mode.dtype_cast_counter, 1) def test_cache_disabled(self): data = torch.randn(2, 3).cuda() weight = torch.nn.Parameter(torch.randn(4, 3).cuda()) try: torch._C._set_cached_tensors_enabled(True) torch._C._add_cached_tensor(weight) with WeightDTypeCastCounterMode(weight) as mode: with torch.autocast(device_type="cuda"): output = CustomLinear.apply(data, weight) s = output.sum() s.backward() # we should not have cached the conversion of the weight self.assertEqual(mode.dtype_cast_counter, 2) finally: torch._C._set_cached_tensors_enabled(False) # index_put under AMP follows a cast policy called "promote", # https://github.com/pytorch/pytorch/blob/4fcd15a667df5b80e81db6563d8d3123a0cbd051/aten/src/ATen/autocast_mode.h#L205-L230 # That means: # (1) double precision is ignored, # (2) if any argument is float, then all arguments are promoted to float, # (3) if all arguments are of lower precision dtype, then all dtypes must be equal to the same amp autocast dtype. # Since AMP autocast dtype is thread-local, it is not preserved across thread boundaries during autograd execution, # and due to the multi-threaded nature of the autograd, the forward pass is being run in bfloat16, while the backward # pass defaults to float16. The dtype mismatch leads to the error in the policy, as the criteria (3) is not satisfied. # For more info see https://github.com/pytorch/pytorch/issues/132715. def test_autocast_prioritize(self): device = "cuda" dtype = torch.bfloat16 with torch.autocast(device_type=device, enabled=True, dtype=dtype): t = torch.randn([3, 4, 5], dtype=dtype, device=device, requires_grad=True) index = torch.randint( low=0, high=3, size=[3, 4, 5], dtype=torch.int64, device=device ) val = torch.randn(1, dtype=dtype, device=device) res = torch.index_put(t, [index], val) loss = res.mean() loss.backward() @unittest.skipIf(not torch.backends.mps.is_available(), "requires mps") class TestAutocastMPS(TestCase): def test_cast_cache_is_global(self): class CustomLinear(torch.autograd.Function): @staticmethod def forward(ctx, x, w_t): ctx.save_for_backward(x, w_t) return torch.nn.functional.linear(x, w_t) @staticmethod def backward(ctx, grad_output): x, w_t = ctx.saved_tensors with torch.autocast(device_type="mps"): dL_dX = torch.matmul(grad_output, w_t) dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1) return dL_dX, dL_dW data = torch.randn(2, 3).to("mps") weight = torch.nn.Parameter(torch.randn(4, 3).to("mps")) weight_dtype_cast_counter = 0 class WeightDTypeCastCounterMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): if ( func is torch.ops.aten._to_copy.default and args[0] is weight and kwargs["dtype"] is torch.float16 ): nonlocal weight_dtype_cast_counter weight_dtype_cast_counter += 1 return func(*args, **kwargs) def __enter__(self): # self.old_clear_cache = torch.clear_autocast_cache # torch.clear_autocast_cache = lambda: None return super().__enter__() def __exit__(self, exc_type, exc_val, exc_tb): # torch.clear_autocast_cache = self.old_clear_cache return super().__exit__(exc_type, exc_val, exc_tb) with WeightDTypeCastCounterMode(): with torch.autocast(device_type="mps"): output = CustomLinear.apply(data, weight) s = output.sum() s.backward() self.assertEqual(weight_dtype_cast_counter, 2) class TestTorchAutocast(TestCase): def test_autocast_fast_dtype(self): gpu_fast_dtype = torch.get_autocast_dtype(device_type="cuda") cpu_fast_dtype = torch.get_autocast_dtype(device_type="cpu") self.assertEqual(gpu_fast_dtype, torch.half) self.assertEqual(cpu_fast_dtype, torch.bfloat16) def test_invalid_device(self): dev = "not a real device" msg = f"Invalid device string: '{dev}'" with self.assertRaisesRegex(RuntimeError, msg): with torch.autocast(device_type=dev): _ = torch.tensor(1) with self.assertRaisesRegex(RuntimeError, msg): assert torch.amp.is_autocast_available(device_type=dev) def test_non_string_device(self): """Test that `autocast` throws a ValueError when provided a `torch.device` object for `device_type` instead of a string""" dev = torch.device("cpu") msg = f"Expected `device_type` of type `str`, got: `{type(dev)}`" with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg): torch.autocast(device_type=dev) if __name__ == "__main__": run_tests()