xref: /aosp_15_r20/external/pytorch/test/test_ops_gradients.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom functools import partial
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import (
7*da0073e9SAndroid Build Coastguard Worker    instantiate_device_type_tests,
8*da0073e9SAndroid Build Coastguard Worker    OpDTypes,
9*da0073e9SAndroid Build Coastguard Worker    ops,
10*da0073e9SAndroid Build Coastguard Worker)
11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import op_db
12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
13*da0073e9SAndroid Build Coastguard Worker    run_tests,
14*da0073e9SAndroid Build Coastguard Worker    TestCase,
15*da0073e9SAndroid Build Coastguard Worker    TestGradients,
16*da0073e9SAndroid Build Coastguard Worker    unMarkDynamoStrictTest,
17*da0073e9SAndroid Build Coastguard Worker)
18*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.custom_op_db import custom_op_db
19*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.hop_db import hop_db
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker# gradcheck requires double precision
23*da0073e9SAndroid Build Coastguard Worker_gradcheck_ops = partial(
24*da0073e9SAndroid Build Coastguard Worker    ops, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, torch.cdouble]
25*da0073e9SAndroid Build Coastguard Worker)
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker
@unMarkDynamoStrictTest
class TestBwdGradients(TestGradients):
    """Backward-mode autograd (grad and gradgrad) checks over the op databases.

    Every test is parametrized with ``_gradcheck_ops``, i.e. only the
    double-precision dtypes (``torch.double`` / ``torch.cdouble``).
    """

    # Tests that gradients are computed correctly
    @_gradcheck_ops(op_db + hop_db + custom_op_db)
    def test_fn_grad(self, device, dtype, op):
        # This is verified by test_dtypes in test_ops.py.
        # NOTE(review): unlike the other tests this does not call _skip_helper;
        # it only skips dtypes not listed as backward-supported. Presumably
        # intentional — confirm.
        if dtype not in op.supported_backward_dtypes(torch.device(device).type):
            self.skipTest("Skipped! Dtype is not in supported backward dtypes!")
        self._grad_test_helper(device, dtype, op, op.get_op())

    # Method grad (and gradgrad, see below) tests are disabled since they're
    #   costly and redundant with function grad (and gradgrad) tests
    # @_gradcheck_ops(op_db)
    # def test_method_grad(self, device, dtype, op):
    #     self._skip_helper(op, device, dtype)
    #     self._grad_test_helper(device, dtype, op, op.get_method())

    # Tests gradients of the in-place variant. For ops that claim not to support
    # in-place autograd, instead verifies that the claim actually holds.
    @_gradcheck_ops(op_db + custom_op_db)
    def test_inplace_grad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.inplace_variant:
            self.skipTest("Op has no inplace variant!")

        # Both paths below exercise the same wrapped in-place variant.
        inplace = self._get_safe_inplace(op.get_inplace())
        if op.supports_inplace_autograd:
            self._grad_test_helper(device, dtype, op, inplace)
            return

        # Verifies an operation doesn't support inplace autograd if it claims not to
        for sample in op.sample_inputs(device, dtype, requires_grad=True):
            # Samples that broadcast their input are not used for this check.
            if sample.broadcasts_input:
                continue
            with self.assertRaises(Exception):
                result = inplace(sample)
                result.sum().backward()

    # Test that gradients of gradients are computed correctly
    @_gradcheck_ops(op_db + hop_db + custom_op_db)
    def test_fn_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.supports_gradgrad:
            # The negative case is covered by test_fn_fail_gradgrad.
            self.skipTest(
                "Op claims it doesn't support gradgrad. This is not verified."
            )
        self._check_helper(device, dtype, op, op.get_op(), "bwgrad_bwgrad")

    # Test that gradients of gradients are properly raising
    @_gradcheck_ops(op_db + custom_op_db)
    def test_fn_fail_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if op.supports_gradgrad:
            self.skipTest("Skipped! Operation does support gradgrad")

        # An op that declares no gradgrad support must fail the double-backward
        # check with a "not implemented" derivative error.
        err_msg = r"derivative for .* is not implemented"
        with self.assertRaisesRegex(RuntimeError, err_msg):
            self._check_helper(device, dtype, op, op.get_op(), "bwgrad_bwgrad")

    # Method gradgrad (and grad, see above) tests are disabled since they're
    #   costly and redundant with function gradgrad (and grad) tests
    # @_gradcheck_ops(op_db)
    # def test_method_gradgrad(self, device, dtype, op):
    #     self._skip_helper(op, device, dtype)
    #     self._check_helper(device, dtype, op, op.get_method(), "bwgrad_bwgrad")

    # Test that gradients of gradients of the in-place variant are correct
    @_gradcheck_ops(op_db)
    def test_inplace_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.inplace_variant or not op.supports_inplace_autograd:
            self.skipTest("Skipped! Operation does not support inplace autograd.")
        # Mirror test_fn_gradgrad. A "bwgrad_bwgrad" check cannot pass for an
        # op that declares no gradgrad support. test_fn_fail_gradgrad verifies
        # that declaration.
        if not op.supports_gradgrad:
            self.skipTest("Skipped! Operation does not support gradgrad.")
        self._check_helper(
            device, dtype, op, self._get_safe_inplace(op.get_inplace()), "bwgrad_bwgrad"
        )
103*da0073e9SAndroid Build Coastguard Worker
104*da0073e9SAndroid Build Coastguard Worker
# Instantiate device-type-specific variants of the generic TestBwdGradients
# class, registering them in this module's namespace (globals()).
instantiate_device_type_tests(TestBwdGradients, globals())
106*da0073e9SAndroid Build Coastguard Worker
def _main():
    # Enable TestCase's default-dtype check only when this file is executed
    # directly, then hand control to the shared test runner.
    TestCase._default_dtype_check_enabled = True
    run_tests()


if __name__ == "__main__":
    _main()
110