xref: /aosp_15_r20/external/pytorch/test/test_cuda_trace.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: cuda"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport sys
4*da0073e9SAndroid Build Coastguard Workerimport unittest
5*da0073e9SAndroid Build Coastguard Workerimport unittest.mock
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Workerimport torch.cuda._gpu_trace as gpu_trace
9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import NoTest, run_tests, TEST_CUDA, TestCase
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker# NOTE: Each test needs to be run in a brand new process, to reset the registered hooks
13*da0073e9SAndroid Build Coastguard Worker# and make sure the CUDA streams are initialized for each test that uses them.
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Workerif not TEST_CUDA:
16*da0073e9SAndroid Build Coastguard Worker    print("CUDA not available, skipping tests", file=sys.stderr)
17*da0073e9SAndroid Build Coastguard Worker    TestCase = NoTest  # noqa: F811
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker@torch.testing._internal.common_utils.markDynamoStrictTest
21*da0073e9SAndroid Build Coastguard Workerclass TestCudaTrace(TestCase):
22*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
23*da0073e9SAndroid Build Coastguard Worker        torch._C._activate_gpu_trace()
24*da0073e9SAndroid Build Coastguard Worker        self.mock = unittest.mock.MagicMock()
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker    def test_event_creation_callback(self):
27*da0073e9SAndroid Build Coastguard Worker        gpu_trace.register_callback_for_event_creation(self.mock)
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker        event = torch.cuda.Event()
30*da0073e9SAndroid Build Coastguard Worker        event.record()
31*da0073e9SAndroid Build Coastguard Worker        self.mock.assert_called_once_with(event._as_parameter_.value)
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker    def test_event_deletion_callback(self):
34*da0073e9SAndroid Build Coastguard Worker        gpu_trace.register_callback_for_event_deletion(self.mock)
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker        event = torch.cuda.Event()
37*da0073e9SAndroid Build Coastguard Worker        event.record()
38*da0073e9SAndroid Build Coastguard Worker        event_id = event._as_parameter_.value
39*da0073e9SAndroid Build Coastguard Worker        del event
40*da0073e9SAndroid Build Coastguard Worker        self.mock.assert_called_once_with(event_id)
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker    def test_event_record_callback(self):
43*da0073e9SAndroid Build Coastguard Worker        gpu_trace.register_callback_for_event_record(self.mock)
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker        event = torch.cuda.Event()
46*da0073e9SAndroid Build Coastguard Worker        event.record()
47*da0073e9SAndroid Build Coastguard Worker        self.mock.assert_called_once_with(
48*da0073e9SAndroid Build Coastguard Worker            event._as_parameter_.value, torch.cuda.default_stream().cuda_stream
49*da0073e9SAndroid Build Coastguard Worker        )
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker    def test_event_wait_callback(self):
52*da0073e9SAndroid Build Coastguard Worker        gpu_trace.register_callback_for_event_wait(self.mock)
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker        event = torch.cuda.Event()
55*da0073e9SAndroid Build Coastguard Worker        event.record()
56*da0073e9SAndroid Build Coastguard Worker        event.wait()
57*da0073e9SAndroid Build Coastguard Worker        self.mock.assert_called_once_with(
58*da0073e9SAndroid Build Coastguard Worker            event._as_parameter_.value, torch.cuda.default_stream().cuda_stream
59*da0073e9SAndroid Build Coastguard Worker        )
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker    def test_memory_allocation_callback(self):
62*da0073e9SAndroid Build Coastguard Worker        gpu_trace.register_callback_for_memory_allocation(self.mock)
63*da0073e9SAndroid Build Coastguard Worker
64*da0073e9SAndroid Build Coastguard Worker        tensor = torch.empty(10, 4, device="cuda")
65*da0073e9SAndroid Build Coastguard Worker        self.mock.assert_called_once_with(tensor.data_ptr())
66*da0073e9SAndroid Build Coastguard Worker
67*da0073e9SAndroid Build Coastguard Worker    def test_memory_deallocation_callback(self):
68*da0073e9SAndroid Build Coastguard Worker        gpu_trace.register_callback_for_memory_deallocation(self.mock)
69*da0073e9SAndroid Build Coastguard Worker
70*da0073e9SAndroid Build Coastguard Worker        tensor = torch.empty(3, 8, device="cuda")
71*da0073e9SAndroid Build Coastguard Worker        data_ptr = tensor.data_ptr()
72*da0073e9SAndroid Build Coastguard Worker        del tensor
73*da0073e9SAndroid Build Coastguard Worker        self.mock.assert_called_once_with(data_ptr)
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker    def test_stream_creation_callback(self):
76*da0073e9SAndroid Build Coastguard Worker        gpu_trace.register_callback_for_stream_creation(self.mock)
77*da0073e9SAndroid Build Coastguard Worker
78*da0073e9SAndroid Build Coastguard Worker        # see Note [HIP Lazy Streams]
79*da0073e9SAndroid Build Coastguard Worker        if torch.version.hip:
80*da0073e9SAndroid Build Coastguard Worker            user_stream = torch.cuda.Stream()
81*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.stream(user_stream):
82*da0073e9SAndroid Build Coastguard Worker                tensor = torch.ones(5, device="cuda")
83*da0073e9SAndroid Build Coastguard Worker        else:
84*da0073e9SAndroid Build Coastguard Worker            torch.cuda.Stream()
85*da0073e9SAndroid Build Coastguard Worker
86*da0073e9SAndroid Build Coastguard Worker        self.mock.assert_called()
87*da0073e9SAndroid Build Coastguard Worker
88*da0073e9SAndroid Build Coastguard Worker    def test_device_synchronization_callback(self):
89*da0073e9SAndroid Build Coastguard Worker        gpu_trace.register_callback_for_device_synchronization(self.mock)
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize()
92*da0073e9SAndroid Build Coastguard Worker        self.mock.assert_called()
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Worker    def test_stream_synchronization_callback(self):
95*da0073e9SAndroid Build Coastguard Worker        gpu_trace.register_callback_for_stream_synchronization(self.mock)
96*da0073e9SAndroid Build Coastguard Worker
97*da0073e9SAndroid Build Coastguard Worker        stream = torch.cuda.Stream()
98*da0073e9SAndroid Build Coastguard Worker        stream.synchronize()
99*da0073e9SAndroid Build Coastguard Worker        self.mock.assert_called_once_with(stream.cuda_stream)
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker    def test_event_synchronization_callback(self):
102*da0073e9SAndroid Build Coastguard Worker        gpu_trace.register_callback_for_event_synchronization(self.mock)
103*da0073e9SAndroid Build Coastguard Worker
104*da0073e9SAndroid Build Coastguard Worker        event = torch.cuda.Event()
105*da0073e9SAndroid Build Coastguard Worker        event.record()
106*da0073e9SAndroid Build Coastguard Worker        event.synchronize()
107*da0073e9SAndroid Build Coastguard Worker        self.mock.assert_called_once_with(event._as_parameter_.value)
108*da0073e9SAndroid Build Coastguard Worker
109*da0073e9SAndroid Build Coastguard Worker    def test_memcpy_synchronization(self):
110*da0073e9SAndroid Build Coastguard Worker        gpu_trace.register_callback_for_stream_synchronization(self.mock)
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker        tensor = torch.rand(5, device="cuda")
113*da0073e9SAndroid Build Coastguard Worker        tensor.nonzero()
114*da0073e9SAndroid Build Coastguard Worker        self.mock.assert_called_once_with(torch.cuda.default_stream().cuda_stream)
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker    def test_all_trace_callbacks_called(self):
117*da0073e9SAndroid Build Coastguard Worker        other = unittest.mock.MagicMock()
118*da0073e9SAndroid Build Coastguard Worker        gpu_trace.register_callback_for_memory_allocation(self.mock)
119*da0073e9SAndroid Build Coastguard Worker        gpu_trace.register_callback_for_memory_allocation(other)
120*da0073e9SAndroid Build Coastguard Worker
121*da0073e9SAndroid Build Coastguard Worker        tensor = torch.empty(10, 4, device="cuda")
122*da0073e9SAndroid Build Coastguard Worker        self.mock.assert_called_once_with(tensor.data_ptr())
123*da0073e9SAndroid Build Coastguard Worker        other.assert_called_once_with(tensor.data_ptr())
124*da0073e9SAndroid Build Coastguard Worker
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
127*da0073e9SAndroid Build Coastguard Worker    run_tests()
128