xref: /aosp_15_r20/external/pytorch/test/autograd/test_logging.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: autograd"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport logging
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
class TestAutogradLogging(LoggingTestCase):
    """Checks the log records emitted by the autograd engine during backward."""

    @make_logging_test(autograd=logging.DEBUG)
    def test_logging(self, records):
        """Backward over a small graph logs each executed node once, in order.

        ``records`` is supplied by ``make_logging_test`` (configured here with
        ``autograd=logging.DEBUG``) and holds the log records it captured
        while this test body ran.
        """
        a = torch.rand(10, requires_grad=True)
        b = a.mul(2).div(3).sum()
        c = b.clone()
        # ``b`` is both a backward root and the input of ``c``'s clone.
        # Each node still appears exactly once in ``expected`` below.
        torch.autograd.backward((b, c))

        # Node names expected in the captured messages, one per record, in order.
        expected = [
            "CloneBackward0",
            "SumBackward0",
            "DivBackward0",
            "MulBackward0",
            "AccumulateGrad",
        ]
        messages = [record.getMessage() for record in records]

        # Derive the count from ``expected`` so the two cannot drift apart.
        # List the captured messages so a count mismatch can be debugged.
        self.assertEqual(
            len(messages),
            len(expected),
            msg=(
                f"expected {len(expected)} autograd log records, "
                f"got {len(messages)}: {messages}"
            ),
        )
        for name, message in zip(expected, messages):
            self.assertIn(name, message)
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Deferred import: Dynamo's test runner is only needed when this file is
    # executed directly as a script.
    import torch._dynamo.test_case as dynamo_test_case

    dynamo_test_case.run_tests()
34