xref: /aosp_15_r20/external/pytorch/test/autograd/test_logging.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: autograd"]
2
3import logging
4
5import torch
6from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
7
8
class TestAutogradLogging(LoggingTestCase):
    """Checks the log records captured while running an autograd backward pass."""

    @make_logging_test(autograd=logging.DEBUG)
    def test_logging(self, records):
        """Run backward from two roots and verify the captured log messages.

        ``records`` is supplied by the ``make_logging_test`` decorator
        (presumably the records emitted at DEBUG level for the ``autograd``
        logging setting -- confirm against ``logging_utils``).
        """
        leaf = torch.rand(10, requires_grad=True)
        scaled = leaf.mul(2).div(3)
        total = scaled.sum()
        total_copy = total.clone()
        torch.autograd.backward((total, total_copy))

        # One record is expected per name below; the i-th message must
        # contain the i-th name.
        expected_names = (
            "CloneBackward0",
            "SumBackward0",
            "DivBackward0",
            "MulBackward0",
            "AccumulateGrad",
        )
        self.assertEqual(len(records), 5)

        # The length check above guarantees zip pairs every record.
        for node_name, record in zip(expected_names, records):
            self.assertIn(node_name, record.getMessage())
28
29
# Script entry point: hand control to the dynamo test runner.
if __name__ == "__main__":
    import torch._dynamo.test_case as dynamo_test_case

    dynamo_test_case.run_tests()
34