# Owner(s): ["module: autograd"] import contextlib import warnings import numpy as np import torch from torch.library import _scoped_library, Library from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, TestCase, ) @contextlib.contextmanager def autograd_fallback_mode(mode): prev = torch._C._get_autograd_fallback_mode() try: torch._C._set_autograd_fallback_mode(mode) yield finally: torch._C._set_autograd_fallback_mode(prev) class TestAutogradFallback(TestCase): test_ns = "_test_autograd_fallback" def tearDown(self): if hasattr(torch.ops, self.test_ns): delattr(torch.ops, self.test_ns) if hasattr(self, "lib"): del self.lib.m del self.lib def get_op(self, name): return getattr(getattr(torch.ops, self.test_ns), name).default def get_lib(self): lib = Library(self.test_ns, "FRAGMENT") # noqa: TOR901 self.lib = lib return lib @parametrize("mode", ("nothing", "warn")) def test_no_grad(self, mode): with autograd_fallback_mode(mode): lib = self.get_lib() lib.define("foo(Tensor a, Tensor b, int c) -> Tensor") lib.impl("foo", lambda a, b, c: a + b + c, "CPU") op = self.get_op("foo") with warnings.catch_warnings(): warnings.simplefilter("error") with torch.no_grad(): a = torch.randn([], requires_grad=True) b = torch.randn([], requires_grad=True) out = op(a, b, 1) self.assertFalse(out.requires_grad) with warnings.catch_warnings(): warnings.simplefilter("error") a = torch.randn([]) b = torch.randn([]) out = op(a, b, 1) self.assertFalse(out.requires_grad) @parametrize("mode", ("nothing", "warn")) def test_no_autograd_kernel(self, mode): with autograd_fallback_mode(mode): lib = self.get_lib() lib.define("foo(Tensor a, Tensor b, int c) -> Tensor") op = self.get_op("foo") def foo_impl(a, b, c): result = a.detach().numpy() + b.detach().numpy() + c return torch.tensor(result) lib.impl("foo", foo_impl, "CPU") # Some inputs requiring grad a = torch.randn([], requires_grad=False) b = torch.randn([], requires_grad=True) out = op(a, b, 1).sum() with self._check_ctx(mode, mode_nothing_raises=True): out.backward() self.assertIsNone(b.grad) def _check_ctx(self, mode, *, mode_nothing_raises=False): if mode == "warn": return self.assertWarnsRegex( UserWarning, "an autograd kernel was not registered" ) assert mode == "nothing" if mode_nothing_raises: return self.assertRaisesRegex(RuntimeError, "does not require grad") return contextlib.nullcontext() @parametrize("mode", ("nothing", "warn")) def test_no_autograd_kernel_inplace(self, mode): with autograd_fallback_mode(mode): # input modified in-place gets returned as output lib = self.get_lib() lib.define("foo(Tensor(a!) self, Tensor(b!) y) -> (Tensor(a!), Tensor(b!))") op = self.get_op("foo") def foo_impl(x, y): with torch.no_grad(): x.sin_() y.cos_() return x, y lib.impl("foo", foo_impl, "CPU") x = torch.randn(3, requires_grad=True) w = x.clone() v = x.clone() y0 = w[0] y1 = v[1] z0, z1 = op(y0, y1) for tensor in [w, v, z0, z1, y0, y1]: with self._check_ctx(mode): tensor.sum().backward(retain_graph=True) # no outputs: we don't do anything. Maybe we should in the future. # This is not a common failure mode. lib.define("bar(Tensor(a!) 
self) -> ()") op = self.get_op("bar") def bar_impl(x): with torch.no_grad(): x.sin_() lib.impl("bar", bar_impl, "CPU") with warnings.catch_warnings(): warnings.simplefilter("error") x = torch.randn([], requires_grad=True) y = x.clone() z = op(y) y.backward() self.assertEqual(x.grad, torch.ones_like(x)) @parametrize("mode", ("nothing", "warn")) def test_cpu_return_self(self, mode): with autograd_fallback_mode(mode): # To be clear, none of these situations are OK and will lead # to other problems down the line. We're testing them because # it is fairly common to actually do these things. with _scoped_library(self.test_ns, "FRAGMENT") as lib: lib.define("foo(Tensor self) -> Tensor") lib.impl("foo", lambda x: x, "CPU") op = self.get_op("foo") x = torch.randn(3, requires_grad=True) y = op(x).sum() with self._check_ctx(mode): y.backward() self.assertEqual(x.grad, torch.ones_like(x)) lib.define("bar(Tensor(a!) self) -> Tensor(a!)") lib.impl("bar", lambda x: x, "CPU") op = self.get_op("bar") x = torch.randn(3, requires_grad=True) y = op(x).sum() with self._check_ctx(mode): y.backward() self.assertEqual(x.grad, torch.ones_like(x)) @parametrize("mode", ("nothing", "warn")) def test_composite_registered_to_cpu(self, mode): with autograd_fallback_mode(mode): with _scoped_library(self.test_ns, "FRAGMENT") as lib: lib.define("foo(Tensor self) -> Tensor") lib.impl("foo", lambda x: x.sin().sum(), "CPU") op = self.get_op("foo") x = torch.randn(3, requires_grad=True) y = op(x) with self._check_ctx(mode): y.backward() self.assertEqual(x.grad, x.cos()) @parametrize("mode", ("nothing", "warn")) def test_autograd_function_registered_to_cpu(self, mode): with autograd_fallback_mode(mode): with _scoped_library(self.test_ns, "FRAGMENT") as lib: lib.define("foo(Tensor self) -> Tensor") class NumpySin(torch.autograd.Function): @staticmethod def forward(ctx, x): ctx.save_for_backward(x) return torch.tensor(np.sin(x.cpu().numpy())) @staticmethod def backward(ctx, gx): (x,) = ctx.saved_tensors return gx * x.cos() lib.impl("foo", NumpySin.apply, "CPU") op = self.get_op("foo") x = torch.randn(3, requires_grad=True) y = op(x).sum() with self._check_ctx(mode): y.backward() self.assertEqual(x.grad, x.cos()) @parametrize("mode", ("nothing", "warn")) def test_inplace_autograd_function_registered_to_cpu(self, mode): with autograd_fallback_mode(mode): with _scoped_library(self.test_ns, "FRAGMENT") as lib: lib.define("foo(Tensor(a!) self) -> Tensor(a!)") class NumpySin_(torch.autograd.Function): @staticmethod def forward(ctx, x): ctx.save_for_backward(x.clone()) x_np = x.detach().numpy() np.sin(x_np, out=x_np) ctx.mark_dirty(x) return x @staticmethod def backward(ctx, gx): (x,) = ctx.saved_tensors return gx * x.cos() lib.impl("foo", NumpySin_.apply, "CPU") op = self.get_op("foo") x = torch.randn(3, requires_grad=True) z = x.clone() w = z[0] y = op(w) expected = torch.zeros_like(x) expected[0] = x[0].cos() with self._check_ctx(mode): (gx,) = torch.autograd.grad( y, x, torch.ones_like(y), retain_graph=True ) self.assertEqual(gx, expected) expected = torch.ones_like(x) expected[0] = x[0].cos() with self._check_ctx(mode): (gx,) = torch.autograd.grad(z, x, torch.ones_like(z)) self.assertEqual(gx, expected) @parametrize("mode", ("nothing", "warn")) def test_inplace_on_tensor_that_does_not_require_grad(self, mode): # We don't do anything special (that is, we don't rebase history). 
# See NOTE [autograd fallback and in-place operations] for why with autograd_fallback_mode(mode): with _scoped_library(self.test_ns, "FRAGMENT") as lib: # Correct usage of (a!) lib.define("foo(Tensor(a!) self, Tensor other) -> Tensor(a!)") def foo_impl(x, y): x_d = x.detach() y = y.detach() x_d.add_(y) return x lib.impl("foo", foo_impl, "CPU") foo = self.get_op("foo") # Incorrect usage of (a!): user doesn't return tensor as-is lib.define("bar(Tensor(a!) self, Tensor other) -> Tensor(a!)") def bar_impl(x, y): x_d = x.detach() y = y.detach() x_d.add_(y) return x_d.clone() lib.impl("bar", bar_impl, "CPU") bar = self.get_op("bar") # User mutated input tensor but didn't return it. lib.define("baz(Tensor(a!) self, Tensor other) -> ()") def baz_impl(x, y): x_d = x.detach() y = y.detach() x_d.add_(y) lib.impl("baz", baz_impl, "CPU") baz = self.get_op("baz") # Test in-place on non-view for op in (foo, bar, baz): x = torch.randn(3) y = torch.randn(3, requires_grad=True) with self.assertRaisesRegex(RuntimeError, "does not require grad"): z = x.clone() op(z, y) torch.autograd.grad(z, y, torch.ones_like(z), allow_unused=True) # Test in-place on view for op in (foo, bar, baz): x = torch.randn(3) y = torch.randn(3, requires_grad=True) with self.assertRaisesRegex(RuntimeError, "does not require grad"): z = x[:] op(z, y) torch.autograd.grad(z, x, torch.ones_like(z), allow_unused=True) @parametrize("mode", ("nothing", "warn")) def test_post_autograd_returns_leaf(self, mode): with autograd_fallback_mode(mode): lib = self.get_lib() lib.define("foo(Tensor a) -> (Tensor, Tensor)") op = self.get_op("foo") lib.impl( "foo", lambda a: (a.clone(), a.clone().detach().requires_grad_()), "CPU" ) x = torch.randn(3, requires_grad=True) y, z = op(x) with self._check_ctx(mode): z.sum().backward() @parametrize("mode", ("nothing", "warn")) def test_undefined_inputs_outputs(self, mode): with autograd_fallback_mode(mode): lib = self.get_lib() lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)") op = self.get_op("foo") def foo_impl(a, b): return None, b.clone() lib.impl("foo", foo_impl, "CPU") x = torch.randn(3, requires_grad=True) # NB: PyTorch dispatcher treats "None" as undefined Tensor. y, z = op(None, x) with self._check_ctx(mode): z.sum().backward() @parametrize("mode", ("nothing", "warn")) def test_undefined_grads(self, mode): with autograd_fallback_mode(mode): lib = self.get_lib() lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)") op = self.get_op("foo") def foo_impl(a, b): return a.sin(), b.cos() lib.impl("foo", foo_impl, "CPU") x = torch.randn(3, requires_grad=True) y = torch.randn(3) w, z = op(x, y) w = torch._C._functions.UndefinedGrad()(w) z = torch._C._functions.UndefinedGrad()(z) with self._check_ctx(mode): (z + w).sum().backward() @parametrize("mode", ("nothing", "warn")) def test_base_does_not_require_grad(self, mode): with autograd_fallback_mode(mode): lib = self.get_lib() lib.define("foo(Tensor(a!) 
x) -> Tensor(a!)") op = self.get_op("foo") def foo_impl(a): with torch.no_grad(): return a.zero_() lib.impl("foo", foo_impl, "CPU") x = torch.randn(3) y = x[:] y.requires_grad_() w = y[:] self.assertTrue(w._base is x) # Hook should be registered on w, but not w._base op(w) with self._check_ctx(mode): w.sum().backward() @parametrize("mode", ("nothing", "warn")) def test_post_autograd_returns_mix_of_requires_grad_tensors(self, mode): with autograd_fallback_mode(mode): lib = self.get_lib() lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor, Tensor)") op = self.get_op("foo") def foo_impl(a, b): with torch.no_grad(): x = a.clone() z = b.clone() y = a * b return x, y, z lib.impl("foo", foo_impl, "CPU") a = torch.randn(3, requires_grad=True) b = torch.randn(3, requires_grad=True) x, y, z = op(a, b) with self._check_ctx(mode, mode_nothing_raises=True): torch.autograd.grad( x, (a, b), torch.ones_like(x), allow_unused=True, retain_graph=True ) with self._check_ctx(mode, mode_nothing_raises=False): torch.autograd.grad( y, (a, b), torch.ones_like(y), allow_unused=True, retain_graph=True ) with self._check_ctx(mode, mode_nothing_raises=True): torch.autograd.grad( z, (a, b), torch.ones_like(z), allow_unused=True, retain_graph=True ) @parametrize("mode", ("nothing", "warn")) def test_supports_tensor_lists(self, mode): with autograd_fallback_mode(mode): lib = self.get_lib() lib.define("foo(Tensor[] a) -> Tensor[]") op = self.get_op("foo") def foo_impl(a): x, y, z = a with torch.no_grad(): return x + y + z, x * y * z lib.impl("foo", foo_impl, "CPU") x = torch.randn(3, requires_grad=True) y = torch.randn(1, requires_grad=True) z = torch.randn(2, 1, requires_grad=True) a, b = op([x, y, z]) with self._check_ctx(mode, mode_nothing_raises=True): torch.autograd.grad( a, (x, y, z), torch.ones_like(a), allow_unused=True, retain_graph=True, ) with self._check_ctx(mode, mode_nothing_raises=True): torch.autograd.grad( b, (x, y, z), torch.ones_like(b), allow_unused=True, retain_graph=True, ) instantiate_parametrized_tests(TestAutogradFallback) if __name__ == "__main__": run_tests()