1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: autograd"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport logging 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport torch 6*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerclass TestAutogradLogging(LoggingTestCase): 10*da0073e9SAndroid Build Coastguard Worker @make_logging_test(autograd=logging.DEBUG) 11*da0073e9SAndroid Build Coastguard Worker def test_logging(self, records): 12*da0073e9SAndroid Build Coastguard Worker a = torch.rand(10, requires_grad=True) 13*da0073e9SAndroid Build Coastguard Worker b = a.mul(2).div(3).sum() 14*da0073e9SAndroid Build Coastguard Worker c = b.clone() 15*da0073e9SAndroid Build Coastguard Worker torch.autograd.backward((b, c)) 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(records), 5) 18*da0073e9SAndroid Build Coastguard Worker expected = [ 19*da0073e9SAndroid Build Coastguard Worker "CloneBackward0", 20*da0073e9SAndroid Build Coastguard Worker "SumBackward0", 21*da0073e9SAndroid Build Coastguard Worker "DivBackward0", 22*da0073e9SAndroid Build Coastguard Worker "MulBackward0", 23*da0073e9SAndroid Build Coastguard Worker "AccumulateGrad", 24*da0073e9SAndroid Build Coastguard Worker ] 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker for i, record in enumerate(records): 27*da0073e9SAndroid Build Coastguard Worker self.assertIn(expected[i], record.getMessage()) 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 31*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker run_tests() 34